ultralytics 8.0.136
refactor and simplify package (#3748)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -1,421 +1,10 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
import shutil
|
||||
import importlib
|
||||
import sys
|
||||
from difflib import get_close_matches
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR,
|
||||
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn,
|
||||
get_settings, yaml_load, yaml_print)
|
||||
from ultralytics.utils import LOGGER
|
||||
|
||||
# Define valid tasks and modes
|
||||
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
||||
TASKS = 'detect', 'segment', 'classify', 'pose'
|
||||
TASK2DATA = {'detect': 'coco8.yaml', 'segment': 'coco8-seg.yaml', 'classify': 'imagenet100', 'pose': 'coco8-pose.yaml'}
|
||||
TASK2MODEL = {
|
||||
'detect': 'yolov8n.pt',
|
||||
'segment': 'yolov8n-seg.pt',
|
||||
'classify': 'yolov8n-cls.pt',
|
||||
'pose': 'yolov8n-pose.pt'}
|
||||
TASK2METRIC = {
|
||||
'detect': 'metrics/mAP50-95(B)',
|
||||
'segment': 'metrics/mAP50-95(M)',
|
||||
'classify': 'metrics/accuracy_top1',
|
||||
'pose': 'metrics/mAP50-95(P)'}
|
||||
# Set modules in sys.modules under their old name
|
||||
sys.modules['ultralytics.yolo.cfg'] = importlib.import_module('ultralytics.cfg')
|
||||
|
||||
|
||||
CLI_HELP_MSG = \
|
||||
f"""
|
||||
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
|
||||
|
||||
yolo TASK MODE ARGS
|
||||
|
||||
Where TASK (optional) is one of {TASKS}
|
||||
MODE (required) is one of {MODES}
|
||||
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||||
See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'
|
||||
|
||||
1. Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
||||
yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||
|
||||
2. Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||
yolo predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
|
||||
|
||||
3. Val a pretrained detection model at batch-size 1 and image size 640:
|
||||
yolo val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||
|
||||
4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||||
|
||||
5. Run special commands:
|
||||
yolo help
|
||||
yolo checks
|
||||
yolo version
|
||||
yolo settings
|
||||
yolo copy-cfg
|
||||
yolo cfg
|
||||
|
||||
Docs: https://docs.ultralytics.com
|
||||
Community: https://community.ultralytics.com
|
||||
GitHub: https://github.com/ultralytics/ultralytics
|
||||
"""
|
||||
|
||||
# Define keys for arg type checks
|
||||
CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'
|
||||
CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
|
||||
'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
|
||||
'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction') # fraction floats 0.0 - 1.0
|
||||
CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
|
||||
'line_width', 'workspace', 'nbs', 'save_period')
|
||||
CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val',
|
||||
'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop',
|
||||
'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
|
||||
'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile')
|
||||
|
||||
|
||||
def cfg2dict(cfg):
|
||||
"""
|
||||
Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | SimpleNamespace): Configuration object to be converted to a dictionary.
|
||||
|
||||
Returns:
|
||||
cfg (dict): Configuration object in dictionary format.
|
||||
"""
|
||||
if isinstance(cfg, (str, Path)):
|
||||
cfg = yaml_load(cfg) # load dict
|
||||
elif isinstance(cfg, SimpleNamespace):
|
||||
cfg = vars(cfg) # convert to dict
|
||||
return cfg
|
||||
|
||||
|
||||
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
|
||||
"""
|
||||
Load and merge configuration data from a file or dictionary.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | Dict | SimpleNamespace): Configuration data.
|
||||
overrides (str | Dict | optional): Overrides in the form of a file name or a dictionary. Default is None.
|
||||
|
||||
Returns:
|
||||
(SimpleNamespace): Training arguments namespace.
|
||||
"""
|
||||
cfg = cfg2dict(cfg)
|
||||
|
||||
# Merge overrides
|
||||
if overrides:
|
||||
overrides = cfg2dict(overrides)
|
||||
check_cfg_mismatch(cfg, overrides)
|
||||
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
|
||||
|
||||
# Special handling for numeric project/name
|
||||
for k in 'project', 'name':
|
||||
if k in cfg and isinstance(cfg[k], (int, float)):
|
||||
cfg[k] = str(cfg[k])
|
||||
if cfg.get('name') == 'model': # assign model to 'name' arg
|
||||
cfg['name'] = cfg.get('model', '').split('.')[0]
|
||||
LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
|
||||
|
||||
# Type and Value checks
|
||||
for k, v in cfg.items():
|
||||
if v is not None: # None values may be from optional args
|
||||
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
|
||||
elif k in CFG_FRACTION_KEYS:
|
||||
if not isinstance(v, (int, float)):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
|
||||
if not (0.0 <= v <= 1.0):
|
||||
raise ValueError(f"'{k}={v}' is an invalid value. "
|
||||
f"Valid '{k}' values are between 0.0 and 1.0.")
|
||||
elif k in CFG_INT_KEYS and not isinstance(v, int):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"'{k}' must be an int (i.e. '{k}=8')")
|
||||
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
|
||||
|
||||
# Return instance
|
||||
return IterableSimpleNamespace(**cfg)
|
||||
|
||||
|
||||
def _handle_deprecation(custom):
|
||||
"""
|
||||
Hardcoded function to handle deprecated config keys
|
||||
"""
|
||||
|
||||
for key in custom.copy().keys():
|
||||
if key == 'hide_labels':
|
||||
deprecation_warn(key, 'show_labels')
|
||||
custom['show_labels'] = custom.pop('hide_labels') == 'False'
|
||||
if key == 'hide_conf':
|
||||
deprecation_warn(key, 'show_conf')
|
||||
custom['show_conf'] = custom.pop('hide_conf') == 'False'
|
||||
if key == 'line_thickness':
|
||||
deprecation_warn(key, 'line_width')
|
||||
custom['line_width'] = custom.pop('line_thickness')
|
||||
|
||||
return custom
|
||||
|
||||
|
||||
def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
|
||||
"""
|
||||
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
|
||||
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
|
||||
|
||||
Args:
|
||||
custom (dict): a dictionary of custom configuration options
|
||||
base (dict): a dictionary of base configuration options
|
||||
"""
|
||||
custom = _handle_deprecation(custom)
|
||||
base, custom = (set(x.keys()) for x in (base, custom))
|
||||
mismatched = [x for x in custom if x not in base]
|
||||
if mismatched:
|
||||
string = ''
|
||||
for x in mismatched:
|
||||
matches = get_close_matches(x, base) # key list
|
||||
matches = [f'{k}={DEFAULT_CFG_DICT[k]}' if DEFAULT_CFG_DICT.get(k) is not None else k for k in matches]
|
||||
match_str = f'Similar arguments are i.e. {matches}.' if matches else ''
|
||||
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
|
||||
raise SyntaxError(string + CLI_HELP_MSG) from e
|
||||
|
||||
|
||||
def merge_equals_args(args: List[str]) -> List[str]:
|
||||
"""
|
||||
Merges arguments around isolated '=' args in a list of strings.
|
||||
The function considers cases where the first argument ends with '=' or the second starts with '=',
|
||||
as well as when the middle one is an equals sign.
|
||||
|
||||
Args:
|
||||
args (List[str]): A list of strings where each element is an argument.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of strings where the arguments around isolated '=' are merged.
|
||||
"""
|
||||
new_args = []
|
||||
for i, arg in enumerate(args):
|
||||
if arg == '=' and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
|
||||
new_args[-1] += f'={args[i + 1]}'
|
||||
del args[i + 1]
|
||||
elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]: # merge ['arg=', 'val']
|
||||
new_args.append(f'{arg}{args[i + 1]}')
|
||||
del args[i + 1]
|
||||
elif arg.startswith('=') and i > 0: # merge ['arg', '=val']
|
||||
new_args[-1] += arg
|
||||
else:
|
||||
new_args.append(arg)
|
||||
return new_args
|
||||
|
||||
|
||||
def handle_yolo_hub(args: List[str]) -> None:
|
||||
"""
|
||||
Handle Ultralytics HUB command-line interface (CLI) commands.
|
||||
|
||||
This function processes Ultralytics HUB CLI commands such as login and logout.
|
||||
It should be called when executing a script with arguments related to HUB authentication.
|
||||
|
||||
Args:
|
||||
args (List[str]): A list of command line arguments
|
||||
|
||||
Example:
|
||||
python my_script.py hub login your_api_key
|
||||
"""
|
||||
from ultralytics import hub
|
||||
|
||||
if args[0] == 'login':
|
||||
key = args[1] if len(args) > 1 else ''
|
||||
# Log in to Ultralytics HUB using the provided API key
|
||||
hub.login(key)
|
||||
elif args[0] == 'logout':
|
||||
# Log out from Ultralytics HUB
|
||||
hub.logout()
|
||||
|
||||
|
||||
def handle_yolo_settings(args: List[str]) -> None:
|
||||
"""
|
||||
Handle YOLO settings command-line interface (CLI) commands.
|
||||
|
||||
This function processes YOLO settings CLI commands such as reset.
|
||||
It should be called when executing a script with arguments related to YOLO settings management.
|
||||
|
||||
Args:
|
||||
args (List[str]): A list of command line arguments for YOLO settings management.
|
||||
|
||||
Example:
|
||||
python my_script.py yolo settings reset
|
||||
"""
|
||||
path = USER_CONFIG_DIR / 'settings.yaml' # get SETTINGS YAML file path
|
||||
if any(args) and args[0] == 'reset':
|
||||
path.unlink() # delete the settings file
|
||||
get_settings() # create new settings
|
||||
LOGGER.info('Settings reset successfully') # inform the user that settings have been reset
|
||||
yaml_print(path) # print the current settings
|
||||
|
||||
|
||||
def entrypoint(debug=''):
|
||||
"""
|
||||
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
||||
to the package.
|
||||
|
||||
This function allows for:
|
||||
- passing mandatory YOLO args as a list of strings
|
||||
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
|
||||
- specifying the mode, either 'train', 'val', 'test', or 'predict'
|
||||
- running special modes like 'checks'
|
||||
- passing overrides to the package's configuration
|
||||
|
||||
It uses the package's default cfg and initializes it using the passed overrides.
|
||||
Then it calls the CLI function with the composed cfg
|
||||
"""
|
||||
args = (debug.split(' ') if debug else sys.argv)[1:]
|
||||
if not args: # no arguments passed
|
||||
LOGGER.info(CLI_HELP_MSG)
|
||||
return
|
||||
|
||||
special = {
|
||||
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
||||
'checks': checks.check_yolo,
|
||||
'version': lambda: LOGGER.info(__version__),
|
||||
'settings': lambda: handle_yolo_settings(args[1:]),
|
||||
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||
'hub': lambda: handle_yolo_hub(args[1:]),
|
||||
'login': lambda: handle_yolo_hub(args),
|
||||
'copy-cfg': copy_default_cfg}
|
||||
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
||||
|
||||
# Define common mis-uses of special commands, i.e. -h, -help, --help
|
||||
special.update({k[0]: v for k, v in special.items()}) # singular
|
||||
special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith('s')}) # singular
|
||||
special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}}
|
||||
|
||||
overrides = {} # basic overrides, i.e. imgsz=320
|
||||
for a in merge_equals_args(args): # merge spaces around '=' sign
|
||||
if a.startswith('--'):
|
||||
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
|
||||
a = a[2:]
|
||||
if a.endswith(','):
|
||||
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
|
||||
a = a[:-1]
|
||||
if '=' in a:
|
||||
try:
|
||||
re.sub(r' *= *', '=', a) # remove spaces around equals sign
|
||||
k, v = a.split('=', 1) # split on first '=' sign
|
||||
assert v, f"missing '{k}' value"
|
||||
if k == 'cfg': # custom.yaml passed
|
||||
LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
|
||||
overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
|
||||
else:
|
||||
if v.lower() == 'none':
|
||||
v = None
|
||||
elif v.lower() == 'true':
|
||||
v = True
|
||||
elif v.lower() == 'false':
|
||||
v = False
|
||||
else:
|
||||
with contextlib.suppress(Exception):
|
||||
v = eval(v)
|
||||
overrides[k] = v
|
||||
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
||||
check_cfg_mismatch(full_args_dict, {a: ''}, e)
|
||||
|
||||
elif a in TASKS:
|
||||
overrides['task'] = a
|
||||
elif a in MODES:
|
||||
overrides['mode'] = a
|
||||
elif a.lower() in special:
|
||||
special[a.lower()]()
|
||||
return
|
||||
elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
|
||||
overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
|
||||
elif a in DEFAULT_CFG_DICT:
|
||||
raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
|
||||
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
|
||||
else:
|
||||
check_cfg_mismatch(full_args_dict, {a: ''})
|
||||
|
||||
# Check keys
|
||||
check_cfg_mismatch(full_args_dict, overrides)
|
||||
|
||||
# Mode
|
||||
mode = overrides.get('mode', None)
|
||||
if mode is None:
|
||||
mode = DEFAULT_CFG.mode or 'predict'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
|
||||
elif mode not in MODES:
|
||||
if mode not in ('checks', checks):
|
||||
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
|
||||
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
|
||||
checks.check_yolo()
|
||||
return
|
||||
|
||||
# Task
|
||||
task = overrides.pop('task', None)
|
||||
if task:
|
||||
if task not in TASKS:
|
||||
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
|
||||
if 'model' not in overrides:
|
||||
overrides['model'] = TASK2MODEL[task]
|
||||
|
||||
# Model
|
||||
model = overrides.pop('model', DEFAULT_CFG.model)
|
||||
if model is None:
|
||||
model = 'yolov8n.pt'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
|
||||
overrides['model'] = model
|
||||
if 'rtdetr' in model.lower(): # guess architecture
|
||||
from ultralytics import RTDETR
|
||||
model = RTDETR(model) # no task argument
|
||||
elif 'sam' in model.lower():
|
||||
from ultralytics import SAM
|
||||
model = SAM(model)
|
||||
else:
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(model, task=task)
|
||||
if isinstance(overrides.get('pretrained'), str):
|
||||
model.load(overrides['pretrained'])
|
||||
|
||||
# Task Update
|
||||
if task != model.task:
|
||||
if task:
|
||||
LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
|
||||
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
|
||||
task = model.task
|
||||
|
||||
# Mode
|
||||
if mode in ('predict', 'track') and 'source' not in overrides:
|
||||
overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||
else 'https://ultralytics.com/images/bus.jpg'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
||||
elif mode in ('train', 'val'):
|
||||
if 'data' not in overrides:
|
||||
overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
||||
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
|
||||
elif mode == 'export':
|
||||
if 'format' not in overrides:
|
||||
overrides['format'] = DEFAULT_CFG.format or 'torchscript'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
|
||||
|
||||
# Run command in python
|
||||
# getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml
|
||||
getattr(model, mode)(**overrides) # default args from model
|
||||
|
||||
|
||||
# Special modes --------------------------------------------------------------------------------------------------------
|
||||
def copy_default_cfg():
|
||||
"""Copy and create a new default configuration file with '_copy' appended to its name."""
|
||||
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
||||
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
||||
LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'
|
||||
f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Example Usage: entrypoint(debug='yolo predict model=yolov8n.pt')
|
||||
entrypoint(debug='')
|
||||
LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.cfg' is deprecated since '8.0.136' and will be removed in '8.1.0'. "
|
||||
"Please use 'ultralytics.cfg' instead.")
|
||||
|
@ -1,114 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
# Default training settings and hyperparameters for medium-augmentation COCO training
|
||||
|
||||
task: detect # (str) YOLO task, i.e. detect, segment, classify, pose
|
||||
mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark
|
||||
|
||||
# Train settings -------------------------------------------------------------------------------------------------------
|
||||
model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
|
||||
data: # (str, optional) path to data file, i.e. coco128.yaml
|
||||
epochs: 100 # (int) number of epochs to train for
|
||||
patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training
|
||||
batch: 16 # (int) number of images per batch (-1 for AutoBatch)
|
||||
imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
|
||||
save: True # (bool) save train checkpoints and predict results
|
||||
save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1)
|
||||
cache: False # (bool) True/ram, disk or False. Use cache for data loading
|
||||
device: # (int | str | list, optional) device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
|
||||
workers: 8 # (int) number of worker threads for data loading (per RANK if DDP)
|
||||
project: # (str, optional) project name
|
||||
name: # (str, optional) experiment name, results saved to 'project/name' directory
|
||||
exist_ok: False # (bool) whether to overwrite existing experiment
|
||||
pretrained: True # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str)
|
||||
optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
|
||||
verbose: True # (bool) whether to print verbose output
|
||||
seed: 0 # (int) random seed for reproducibility
|
||||
deterministic: True # (bool) whether to enable deterministic mode
|
||||
single_cls: False # (bool) train multi-class data as single-class
|
||||
rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val'
|
||||
cos_lr: False # (bool) use cosine learning rate scheduler
|
||||
close_mosaic: 10 # (int) disable mosaic augmentation for final epochs
|
||||
resume: False # (bool) resume training from last checkpoint
|
||||
amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
|
||||
fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set)
|
||||
profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers
|
||||
# Segmentation
|
||||
overlap_mask: True # (bool) masks should overlap during training (segment train only)
|
||||
mask_ratio: 4 # (int) mask downsample ratio (segment train only)
|
||||
# Classification
|
||||
dropout: 0.0 # (float) use dropout regularization (classify train only)
|
||||
|
||||
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
||||
val: True # (bool) validate/test during training
|
||||
split: val # (str) dataset split to use for validation, i.e. 'val', 'test' or 'train'
|
||||
save_json: False # (bool) save results to JSON file
|
||||
save_hybrid: False # (bool) save hybrid version of labels (labels + additional predictions)
|
||||
conf: # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val)
|
||||
iou: 0.7 # (float) intersection over union (IoU) threshold for NMS
|
||||
max_det: 300 # (int) maximum number of detections per image
|
||||
half: False # (bool) use half precision (FP16)
|
||||
dnn: False # (bool) use OpenCV DNN for ONNX inference
|
||||
plots: True # (bool) save plots during train/val
|
||||
|
||||
# Prediction settings --------------------------------------------------------------------------------------------------
|
||||
source: # (str, optional) source directory for images or videos
|
||||
show: False # (bool) show results if possible
|
||||
save_txt: False # (bool) save results as .txt file
|
||||
save_conf: False # (bool) save results with confidence scores
|
||||
save_crop: False # (bool) save cropped images with results
|
||||
show_labels: True # (bool) show object labels in plots
|
||||
show_conf: True # (bool) show object confidence scores in plots
|
||||
vid_stride: 1 # (int) video frame-rate stride
|
||||
line_width: # (int, optional) line width of the bounding boxes, auto if missing
|
||||
visualize: False # (bool) visualize model features
|
||||
augment: False # (bool) apply image augmentation to prediction sources
|
||||
agnostic_nms: False # (bool) class-agnostic NMS
|
||||
classes: # (int | list[int], optional) filter results by class, i.e. class=0, or class=[0,2,3]
|
||||
retina_masks: False # (bool) use high-resolution segmentation masks
|
||||
boxes: True # (bool) Show boxes in segmentation predictions
|
||||
|
||||
# Export settings ------------------------------------------------------------------------------------------------------
|
||||
format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
|
||||
keras: False # (bool) use Kera=s
|
||||
optimize: False # (bool) TorchScript: optimize for mobile
|
||||
int8: False # (bool) CoreML/TF INT8 quantization
|
||||
dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes
|
||||
simplify: False # (bool) ONNX: simplify model
|
||||
opset: # (int, optional) ONNX: opset version
|
||||
workspace: 4 # (int) TensorRT: workspace size (GB)
|
||||
nms: False # (bool) CoreML: add NMS
|
||||
|
||||
# Hyperparameters ------------------------------------------------------------------------------------------------------
|
||||
lr0: 0.01 # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
|
||||
lrf: 0.01 # (float) final learning rate (lr0 * lrf)
|
||||
momentum: 0.937 # (float) SGD momentum/Adam beta1
|
||||
weight_decay: 0.0005 # (float) optimizer weight decay 5e-4
|
||||
warmup_epochs: 3.0 # (float) warmup epochs (fractions ok)
|
||||
warmup_momentum: 0.8 # (float) warmup initial momentum
|
||||
warmup_bias_lr: 0.1 # (float) warmup initial bias lr
|
||||
box: 7.5 # (float) box loss gain
|
||||
cls: 0.5 # (float) cls loss gain (scale with pixels)
|
||||
dfl: 1.5 # (float) dfl loss gain
|
||||
pose: 12.0 # (float) pose loss gain
|
||||
kobj: 1.0 # (float) keypoint obj loss gain
|
||||
label_smoothing: 0.0 # (float) label smoothing (fraction)
|
||||
nbs: 64 # (int) nominal batch size
|
||||
hsv_h: 0.015 # (float) image HSV-Hue augmentation (fraction)
|
||||
hsv_s: 0.7 # (float) image HSV-Saturation augmentation (fraction)
|
||||
hsv_v: 0.4 # (float) image HSV-Value augmentation (fraction)
|
||||
degrees: 0.0 # (float) image rotation (+/- deg)
|
||||
translate: 0.1 # (float) image translation (+/- fraction)
|
||||
scale: 0.5 # (float) image scale (+/- gain)
|
||||
shear: 0.0 # (float) image shear (+/- deg)
|
||||
perspective: 0.0 # (float) image perspective (+/- fraction), range 0-0.001
|
||||
flipud: 0.0 # (float) image flip up-down (probability)
|
||||
fliplr: 0.5 # (float) image flip left-right (probability)
|
||||
mosaic: 1.0 # (float) image mosaic (probability)
|
||||
mixup: 0.0 # (float) image mixup (probability)
|
||||
copy_paste: 0.0 # (float) segment copy-paste (probability)
|
||||
|
||||
# Custom config.yaml ---------------------------------------------------------------------------------------------------
|
||||
cfg: # (str, optional) for overriding defaults.yaml
|
||||
|
||||
# Tracker settings ------------------------------------------------------------------------------------------------------
|
||||
tracker: botsort.yaml # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml]
|
@ -1,9 +1,17 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
from .base import BaseDataset
|
||||
from .build import build_dataloader, build_yolo_dataset, load_inference_source
|
||||
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
|
||||
from .dataset_wrappers import MixAndRectDataset
|
||||
from ultralytics.utils import LOGGER
|
||||
|
||||
__all__ = ('BaseDataset', 'ClassificationDataset', 'MixAndRectDataset', 'SemanticDataset', 'YOLODataset',
|
||||
'build_yolo_dataset', 'build_dataloader', 'load_inference_source')
|
||||
# Set modules in sys.modules under their old name
|
||||
sys.modules['ultralytics.yolo.data'] = importlib.import_module('ultralytics.data')
|
||||
# This is for updating old cls models, or the way in following warning won't work.
|
||||
sys.modules['ultralytics.yolo.data.augment'] = importlib.import_module('ultralytics.data.augment')
|
||||
|
||||
DATA_WARNING = """WARNING ⚠️ 'ultralytics.yolo.data' is deprecated since '8.0.136' and will be removed in '8.1.0'. Please use 'ultralytics.data' instead.
|
||||
Note this warning may be related to loading older models. You can update your model to current structure with:
|
||||
import torch
|
||||
ckpt = torch.load("model.pt") # applies to both official and custom models
|
||||
torch.save(ckpt, "updated-model.pt")
|
||||
"""
|
||||
LOGGER.warning(DATA_WARNING)
|
||||
|
@ -1,39 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics import SAM, YOLO
|
||||
|
||||
|
||||
def auto_annotate(data, det_model='yolov8x.pt', sam_model='sam_b.pt', device='', output_dir=None):
|
||||
"""
|
||||
Automatically annotates images using a YOLO object detection model and a SAM segmentation model.
|
||||
Args:
|
||||
data (str): Path to a folder containing images to be annotated.
|
||||
det_model (str, optional): Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'.
|
||||
sam_model (str, optional): Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'.
|
||||
device (str, optional): Device to run the models on. Defaults to an empty string (CPU or GPU, if available).
|
||||
output_dir (str | None | optional): Directory to save the annotated results.
|
||||
Defaults to a 'labels' folder in the same directory as 'data'.
|
||||
"""
|
||||
det_model = YOLO(det_model)
|
||||
sam_model = SAM(sam_model)
|
||||
|
||||
if not output_dir:
|
||||
output_dir = Path(str(data)).parent / 'labels'
|
||||
Path(output_dir).mkdir(exist_ok=True, parents=True)
|
||||
|
||||
det_results = det_model(data, stream=True, device=device)
|
||||
|
||||
for result in det_results:
|
||||
boxes = result.boxes.xyxy # Boxes object for bbox outputs
|
||||
class_ids = result.boxes.cls.int().tolist() # noqa
|
||||
if len(class_ids):
|
||||
sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
|
||||
segments = sam_results[0].masks.xyn # noqa
|
||||
|
||||
with open(str(Path(output_dir) / Path(result.path).stem) + '.txt', 'w') as f:
|
||||
for i in range(len(segments)):
|
||||
s = segments[i]
|
||||
if len(s) == 0:
|
||||
continue
|
||||
segment = map(str, segments[i].reshape(-1).tolist())
|
||||
f.write(f'{class_ids[i]} ' + ' '.join(segment) + '\n')
|
@ -1,905 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import math
|
||||
import random
|
||||
from copy import deepcopy
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
|
||||
from ..utils import LOGGER, colorstr
|
||||
from ..utils.checks import check_version
|
||||
from ..utils.instance import Instances
|
||||
from ..utils.metrics import bbox_ioa
|
||||
from ..utils.ops import segment2box
|
||||
from .utils import polygons2masks, polygons2masks_overlap
|
||||
|
||||
POSE_FLIPLR_INDEX = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
|
||||
|
||||
|
||||
# TODO: we might need a BaseTransform to make all these augments be compatible with both classification and semantic
|
||||
class BaseTransform:
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def apply_image(self, labels):
|
||||
"""Applies image transformation to labels."""
|
||||
pass
|
||||
|
||||
def apply_instances(self, labels):
|
||||
"""Applies transformations to input 'labels' and returns object instances."""
|
||||
pass
|
||||
|
||||
def apply_semantic(self, labels):
|
||||
"""Applies semantic segmentation to an image."""
|
||||
pass
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Applies label transformations to an image, instances and semantic masks."""
|
||||
self.apply_image(labels)
|
||||
self.apply_instances(labels)
|
||||
self.apply_semantic(labels)
|
||||
|
||||
|
||||
class Compose:
|
||||
|
||||
def __init__(self, transforms):
|
||||
"""Initializes the Compose object with a list of transforms."""
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, data):
|
||||
"""Applies a series of transformations to input data."""
|
||||
for t in self.transforms:
|
||||
data = t(data)
|
||||
return data
|
||||
|
||||
def append(self, transform):
|
||||
"""Appends a new transform to the existing list of transforms."""
|
||||
self.transforms.append(transform)
|
||||
|
||||
def tolist(self):
|
||||
"""Converts list of transforms to a standard Python list."""
|
||||
return self.transforms
|
||||
|
||||
def __repr__(self):
|
||||
"""Return string representation of object."""
|
||||
format_string = f'{self.__class__.__name__}('
|
||||
for t in self.transforms:
|
||||
format_string += '\n'
|
||||
format_string += f' {t}'
|
||||
format_string += '\n)'
|
||||
return format_string
|
||||
|
||||
|
||||
class BaseMixTransform:
|
||||
"""This implementation is from mmyolo."""
|
||||
|
||||
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
|
||||
self.dataset = dataset
|
||||
self.pre_transform = pre_transform
|
||||
self.p = p
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Applies pre-processing transforms and mixup/mosaic transforms to labels data."""
|
||||
if random.uniform(0, 1) > self.p:
|
||||
return labels
|
||||
|
||||
# Get index of one or three other images
|
||||
indexes = self.get_indexes()
|
||||
if isinstance(indexes, int):
|
||||
indexes = [indexes]
|
||||
|
||||
# Get images information will be used for Mosaic or MixUp
|
||||
mix_labels = [self.dataset.get_image_and_label(i) for i in indexes]
|
||||
|
||||
if self.pre_transform is not None:
|
||||
for i, data in enumerate(mix_labels):
|
||||
mix_labels[i] = self.pre_transform(data)
|
||||
labels['mix_labels'] = mix_labels
|
||||
|
||||
# Mosaic or MixUp
|
||||
labels = self._mix_transform(labels)
|
||||
labels.pop('mix_labels', None)
|
||||
return labels
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
"""Applies MixUp or Mosaic augmentation to the label dictionary."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_indexes(self):
|
||||
"""Gets a list of shuffled indexes for mosaic augmentation."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Mosaic(BaseMixTransform):
|
||||
"""
|
||||
Mosaic augmentation.
|
||||
|
||||
This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image.
|
||||
The augmentation is applied to a dataset with a given probability.
|
||||
|
||||
Attributes:
|
||||
dataset: The dataset on which the mosaic augmentation is applied.
|
||||
imgsz (int, optional): Image size (height and width) after mosaic pipeline of a single image. Default to 640.
|
||||
p (float, optional): Probability of applying the mosaic augmentation. Must be in the range 0-1. Default to 1.0.
|
||||
n (int, optional): The grid size, either 4 (for 2x2) or 9 (for 3x3).
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
|
||||
"""Initializes the object with a dataset, image size, probability, and border."""
|
||||
assert 0 <= p <= 1.0, f'The probability should be in range [0, 1], but got {p}.'
|
||||
assert n in (4, 9), 'grid must be equal to 4 or 9.'
|
||||
super().__init__(dataset=dataset, p=p)
|
||||
self.dataset = dataset
|
||||
self.imgsz = imgsz
|
||||
self.border = (-imgsz // 2, -imgsz // 2) # width, height
|
||||
self.n = n
|
||||
|
||||
def get_indexes(self, buffer=True):
|
||||
"""Return a list of random indexes from the dataset."""
|
||||
if buffer: # select images from buffer
|
||||
return random.choices(list(self.dataset.buffer), k=self.n - 1)
|
||||
else: # select any images
|
||||
return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
"""Apply mixup transformation to the input image and labels."""
|
||||
assert labels.get('rect_shape', None) is None, 'rect and mosaic are mutually exclusive.'
|
||||
assert len(labels.get('mix_labels', [])), 'There are no other images for mosaic augment.'
|
||||
return self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels)
|
||||
|
||||
def _mosaic4(self, labels):
|
||||
"""Create a 2x2 image mosaic."""
|
||||
mosaic_labels = []
|
||||
s = self.imgsz
|
||||
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
|
||||
for i in range(4):
|
||||
labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
|
||||
# Load image
|
||||
img = labels_patch['img']
|
||||
h, w = labels_patch.pop('resized_shape')
|
||||
|
||||
# Place img in img4
|
||||
if i == 0: # top left
|
||||
img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
|
||||
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
|
||||
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
|
||||
elif i == 1: # top right
|
||||
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
|
||||
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
|
||||
elif i == 2: # bottom left
|
||||
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
|
||||
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
|
||||
elif i == 3: # bottom right
|
||||
x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
|
||||
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
|
||||
|
||||
img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
|
||||
padw = x1a - x1b
|
||||
padh = y1a - y1b
|
||||
|
||||
labels_patch = self._update_labels(labels_patch, padw, padh)
|
||||
mosaic_labels.append(labels_patch)
|
||||
final_labels = self._cat_labels(mosaic_labels)
|
||||
final_labels['img'] = img4
|
||||
return final_labels
|
||||
|
||||
def _mosaic9(self, labels):
|
||||
"""Create a 3x3 image mosaic."""
|
||||
mosaic_labels = []
|
||||
s = self.imgsz
|
||||
hp, wp = -1, -1 # height, width previous
|
||||
for i in range(9):
|
||||
labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
|
||||
# Load image
|
||||
img = labels_patch['img']
|
||||
h, w = labels_patch.pop('resized_shape')
|
||||
|
||||
# Place img in img9
|
||||
if i == 0: # center
|
||||
img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
|
||||
h0, w0 = h, w
|
||||
c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
|
||||
elif i == 1: # top
|
||||
c = s, s - h, s + w, s
|
||||
elif i == 2: # top right
|
||||
c = s + wp, s - h, s + wp + w, s
|
||||
elif i == 3: # right
|
||||
c = s + w0, s, s + w0 + w, s + h
|
||||
elif i == 4: # bottom right
|
||||
c = s + w0, s + hp, s + w0 + w, s + hp + h
|
||||
elif i == 5: # bottom
|
||||
c = s + w0 - w, s + h0, s + w0, s + h0 + h
|
||||
elif i == 6: # bottom left
|
||||
c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
|
||||
elif i == 7: # left
|
||||
c = s - w, s + h0 - h, s, s + h0
|
||||
elif i == 8: # top left
|
||||
c = s - w, s + h0 - hp - h, s, s + h0 - hp
|
||||
|
||||
padw, padh = c[:2]
|
||||
x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
|
||||
|
||||
# Image
|
||||
img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax]
|
||||
hp, wp = h, w # height, width previous for next iteration
|
||||
|
||||
# Labels assuming imgsz*2 mosaic size
|
||||
labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1])
|
||||
mosaic_labels.append(labels_patch)
|
||||
final_labels = self._cat_labels(mosaic_labels)
|
||||
|
||||
final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
|
||||
return final_labels
|
||||
|
||||
@staticmethod
|
||||
def _update_labels(labels, padw, padh):
|
||||
"""Update labels."""
|
||||
nh, nw = labels['img'].shape[:2]
|
||||
labels['instances'].convert_bbox(format='xyxy')
|
||||
labels['instances'].denormalize(nw, nh)
|
||||
labels['instances'].add_padding(padw, padh)
|
||||
return labels
|
||||
|
||||
def _cat_labels(self, mosaic_labels):
|
||||
"""Return labels with mosaic border instances clipped."""
|
||||
if len(mosaic_labels) == 0:
|
||||
return {}
|
||||
cls = []
|
||||
instances = []
|
||||
imgsz = self.imgsz * 2 # mosaic imgsz
|
||||
for labels in mosaic_labels:
|
||||
cls.append(labels['cls'])
|
||||
instances.append(labels['instances'])
|
||||
final_labels = {
|
||||
'im_file': mosaic_labels[0]['im_file'],
|
||||
'ori_shape': mosaic_labels[0]['ori_shape'],
|
||||
'resized_shape': (imgsz, imgsz),
|
||||
'cls': np.concatenate(cls, 0),
|
||||
'instances': Instances.concatenate(instances, axis=0),
|
||||
'mosaic_border': self.border} # final_labels
|
||||
final_labels['instances'].clip(imgsz, imgsz)
|
||||
good = final_labels['instances'].remove_zero_area_boxes()
|
||||
final_labels['cls'] = final_labels['cls'][good]
|
||||
return final_labels
|
||||
|
||||
|
||||
class MixUp(BaseMixTransform):
|
||||
|
||||
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
|
||||
super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
|
||||
|
||||
def get_indexes(self):
|
||||
"""Get a random index from the dataset."""
|
||||
return random.randint(0, len(self.dataset) - 1)
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
"""Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf."""
|
||||
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
|
||||
labels2 = labels['mix_labels'][0]
|
||||
labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8)
|
||||
labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0)
|
||||
labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0)
|
||||
return labels
|
||||
|
||||
|
||||
class RandomPerspective:
|
||||
|
||||
def __init__(self,
|
||||
degrees=0.0,
|
||||
translate=0.1,
|
||||
scale=0.5,
|
||||
shear=0.0,
|
||||
perspective=0.0,
|
||||
border=(0, 0),
|
||||
pre_transform=None):
|
||||
self.degrees = degrees
|
||||
self.translate = translate
|
||||
self.scale = scale
|
||||
self.shear = shear
|
||||
self.perspective = perspective
|
||||
# Mosaic border
|
||||
self.border = border
|
||||
self.pre_transform = pre_transform
|
||||
|
||||
def affine_transform(self, img, border):
|
||||
"""Center."""
|
||||
C = np.eye(3, dtype=np.float32)
|
||||
|
||||
C[0, 2] = -img.shape[1] / 2 # x translation (pixels)
|
||||
C[1, 2] = -img.shape[0] / 2 # y translation (pixels)
|
||||
|
||||
# Perspective
|
||||
P = np.eye(3, dtype=np.float32)
|
||||
P[2, 0] = random.uniform(-self.perspective, self.perspective) # x perspective (about y)
|
||||
P[2, 1] = random.uniform(-self.perspective, self.perspective) # y perspective (about x)
|
||||
|
||||
# Rotation and Scale
|
||||
R = np.eye(3, dtype=np.float32)
|
||||
a = random.uniform(-self.degrees, self.degrees)
|
||||
# a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
|
||||
s = random.uniform(1 - self.scale, 1 + self.scale)
|
||||
# s = 2 ** random.uniform(-scale, scale)
|
||||
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
|
||||
|
||||
# Shear
|
||||
S = np.eye(3, dtype=np.float32)
|
||||
S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # x shear (deg)
|
||||
S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # y shear (deg)
|
||||
|
||||
# Translation
|
||||
T = np.eye(3, dtype=np.float32)
|
||||
T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0] # x translation (pixels)
|
||||
T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1] # y translation (pixels)
|
||||
|
||||
# Combined rotation matrix
|
||||
M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
|
||||
# Affine image
|
||||
if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
|
||||
if self.perspective:
|
||||
img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))
|
||||
else: # affine
|
||||
img = cv2.warpAffine(img, M[:2], dsize=self.size, borderValue=(114, 114, 114))
|
||||
return img, M, s
|
||||
|
||||
def apply_bboxes(self, bboxes, M):
|
||||
"""
|
||||
Apply affine to bboxes only.
|
||||
|
||||
Args:
|
||||
bboxes (ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
|
||||
M (ndarray): affine matrix.
|
||||
|
||||
Returns:
|
||||
new_bboxes (ndarray): bboxes after affine, [num_bboxes, 4].
|
||||
"""
|
||||
n = len(bboxes)
|
||||
if n == 0:
|
||||
return bboxes
|
||||
|
||||
xy = np.ones((n * 4, 3), dtype=bboxes.dtype)
|
||||
xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
|
||||
xy = xy @ M.T # transform
|
||||
xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
|
||||
|
||||
# Create new boxes
|
||||
x = xy[:, [0, 2, 4, 6]]
|
||||
y = xy[:, [1, 3, 5, 7]]
|
||||
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
|
||||
|
||||
def apply_segments(self, segments, M):
|
||||
"""
|
||||
Apply affine to segments and generate new bboxes from segments.
|
||||
|
||||
Args:
|
||||
segments (ndarray): list of segments, [num_samples, 500, 2].
|
||||
M (ndarray): affine matrix.
|
||||
|
||||
Returns:
|
||||
new_segments (ndarray): list of segments after affine, [num_samples, 500, 2].
|
||||
new_bboxes (ndarray): bboxes after affine, [N, 4].
|
||||
"""
|
||||
n, num = segments.shape[:2]
|
||||
if n == 0:
|
||||
return [], segments
|
||||
|
||||
xy = np.ones((n * num, 3), dtype=segments.dtype)
|
||||
segments = segments.reshape(-1, 2)
|
||||
xy[:, :2] = segments
|
||||
xy = xy @ M.T # transform
|
||||
xy = xy[:, :2] / xy[:, 2:3]
|
||||
segments = xy.reshape(n, -1, 2)
|
||||
bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0)
|
||||
return bboxes, segments
|
||||
|
||||
def apply_keypoints(self, keypoints, M):
|
||||
"""
|
||||
Apply affine to keypoints.
|
||||
|
||||
Args:
|
||||
keypoints (ndarray): keypoints, [N, 17, 3].
|
||||
M (ndarray): affine matrix.
|
||||
|
||||
Return:
|
||||
new_keypoints (ndarray): keypoints after affine, [N, 17, 3].
|
||||
"""
|
||||
n, nkpt = keypoints.shape[:2]
|
||||
if n == 0:
|
||||
return keypoints
|
||||
xy = np.ones((n * nkpt, 3), dtype=keypoints.dtype)
|
||||
visible = keypoints[..., 2].reshape(n * nkpt, 1)
|
||||
xy[:, :2] = keypoints[..., :2].reshape(n * nkpt, 2)
|
||||
xy = xy @ M.T # transform
|
||||
xy = xy[:, :2] / xy[:, 2:3] # perspective rescale or affine
|
||||
out_mask = (xy[:, 0] < 0) | (xy[:, 1] < 0) | (xy[:, 0] > self.size[0]) | (xy[:, 1] > self.size[1])
|
||||
visible[out_mask] = 0
|
||||
return np.concatenate([xy, visible], axis=-1).reshape(n, nkpt, 3)
|
||||
|
||||
def __call__(self, labels):
|
||||
"""
|
||||
Affine images and targets.
|
||||
|
||||
Args:
|
||||
labels (dict): a dict of `bboxes`, `segments`, `keypoints`.
|
||||
"""
|
||||
if self.pre_transform and 'mosaic_border' not in labels:
|
||||
labels = self.pre_transform(labels)
|
||||
labels.pop('ratio_pad', None) # do not need ratio pad
|
||||
|
||||
img = labels['img']
|
||||
cls = labels['cls']
|
||||
instances = labels.pop('instances')
|
||||
# Make sure the coord formats are right
|
||||
instances.convert_bbox(format='xyxy')
|
||||
instances.denormalize(*img.shape[:2][::-1])
|
||||
|
||||
border = labels.pop('mosaic_border', self.border)
|
||||
self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
|
||||
# M is affine matrix
|
||||
# scale for func:`box_candidates`
|
||||
img, M, scale = self.affine_transform(img, border)
|
||||
|
||||
bboxes = self.apply_bboxes(instances.bboxes, M)
|
||||
|
||||
segments = instances.segments
|
||||
keypoints = instances.keypoints
|
||||
# Update bboxes if there are segments.
|
||||
if len(segments):
|
||||
bboxes, segments = self.apply_segments(segments, M)
|
||||
|
||||
if keypoints is not None:
|
||||
keypoints = self.apply_keypoints(keypoints, M)
|
||||
new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False)
|
||||
# Clip
|
||||
new_instances.clip(*self.size)
|
||||
|
||||
# Filter instances
|
||||
instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
|
||||
# Make the bboxes have the same scale with new_bboxes
|
||||
i = self.box_candidates(box1=instances.bboxes.T,
|
||||
box2=new_instances.bboxes.T,
|
||||
area_thr=0.01 if len(segments) else 0.10)
|
||||
labels['instances'] = new_instances[i]
|
||||
labels['cls'] = cls[i]
|
||||
labels['img'] = img
|
||||
labels['resized_shape'] = img.shape[:2]
|
||||
return labels
|
||||
|
||||
def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
|
||||
# Compute box candidates: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
|
||||
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
|
||||
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
|
||||
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
|
||||
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
|
||||
|
||||
|
||||
class RandomHSV:
|
||||
|
||||
def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None:
|
||||
self.hgain = hgain
|
||||
self.sgain = sgain
|
||||
self.vgain = vgain
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Applies random horizontal or vertical flip to an image with a given probability."""
|
||||
img = labels['img']
|
||||
if self.hgain or self.sgain or self.vgain:
|
||||
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
|
||||
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
|
||||
dtype = img.dtype # uint8
|
||||
|
||||
x = np.arange(0, 256, dtype=r.dtype)
|
||||
lut_hue = ((x * r[0]) % 180).astype(dtype)
|
||||
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
|
||||
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
|
||||
|
||||
im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
|
||||
cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
|
||||
return labels
|
||||
|
||||
|
||||
class RandomFlip:
|
||||
|
||||
def __init__(self, p=0.5, direction='horizontal', flip_idx=None) -> None:
|
||||
assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}'
|
||||
assert 0 <= p <= 1.0
|
||||
|
||||
self.p = p
|
||||
self.direction = direction
|
||||
self.flip_idx = flip_idx
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Resize image and padding for detection, instance segmentation, pose."""
|
||||
img = labels['img']
|
||||
instances = labels.pop('instances')
|
||||
instances.convert_bbox(format='xywh')
|
||||
h, w = img.shape[:2]
|
||||
h = 1 if instances.normalized else h
|
||||
w = 1 if instances.normalized else w
|
||||
|
||||
# Flip up-down
|
||||
if self.direction == 'vertical' and random.random() < self.p:
|
||||
img = np.flipud(img)
|
||||
instances.flipud(h)
|
||||
if self.direction == 'horizontal' and random.random() < self.p:
|
||||
img = np.fliplr(img)
|
||||
instances.fliplr(w)
|
||||
# For keypoints
|
||||
if self.flip_idx is not None and instances.keypoints is not None:
|
||||
instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
|
||||
labels['img'] = np.ascontiguousarray(img)
|
||||
labels['instances'] = instances
|
||||
return labels
|
||||
|
||||
|
||||
class LetterBox:
|
||||
"""Resize image and padding for detection, instance segmentation, pose."""
|
||||
|
||||
def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32):
|
||||
"""Initialize LetterBox object with specific parameters."""
|
||||
self.new_shape = new_shape
|
||||
self.auto = auto
|
||||
self.scaleFill = scaleFill
|
||||
self.scaleup = scaleup
|
||||
self.stride = stride
|
||||
self.center = center # Put the image in the middle or top-left
|
||||
|
||||
def __call__(self, labels=None, image=None):
|
||||
"""Return updated labels and image with added border."""
|
||||
if labels is None:
|
||||
labels = {}
|
||||
img = labels.get('img') if image is None else image
|
||||
shape = img.shape[:2] # current shape [height, width]
|
||||
new_shape = labels.pop('rect_shape', self.new_shape)
|
||||
if isinstance(new_shape, int):
|
||||
new_shape = (new_shape, new_shape)
|
||||
|
||||
# Scale ratio (new / old)
|
||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
||||
if not self.scaleup: # only scale down, do not scale up (for better val mAP)
|
||||
r = min(r, 1.0)
|
||||
|
||||
# Compute padding
|
||||
ratio = r, r # width, height ratios
|
||||
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
||||
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
||||
if self.auto: # minimum rectangle
|
||||
dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
|
||||
elif self.scaleFill: # stretch
|
||||
dw, dh = 0.0, 0.0
|
||||
new_unpad = (new_shape[1], new_shape[0])
|
||||
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
|
||||
|
||||
if self.center:
|
||||
dw /= 2 # divide padding into 2 sides
|
||||
dh /= 2
|
||||
if labels.get('ratio_pad'):
|
||||
labels['ratio_pad'] = (labels['ratio_pad'], (dw, dh)) # for evaluation
|
||||
|
||||
if shape[::-1] != new_unpad: # resize
|
||||
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
|
||||
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
|
||||
value=(114, 114, 114)) # add border
|
||||
|
||||
if len(labels):
|
||||
labels = self._update_labels(labels, ratio, dw, dh)
|
||||
labels['img'] = img
|
||||
labels['resized_shape'] = new_shape
|
||||
return labels
|
||||
else:
|
||||
return img
|
||||
|
||||
def _update_labels(self, labels, ratio, padw, padh):
|
||||
"""Update labels."""
|
||||
labels['instances'].convert_bbox(format='xyxy')
|
||||
labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
|
||||
labels['instances'].scale(*ratio)
|
||||
labels['instances'].add_padding(padw, padh)
|
||||
return labels
|
||||
|
||||
|
||||
class CopyPaste:
|
||||
|
||||
def __init__(self, p=0.5) -> None:
|
||||
self.p = p
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)."""
|
||||
im = labels['img']
|
||||
cls = labels['cls']
|
||||
h, w = im.shape[:2]
|
||||
instances = labels.pop('instances')
|
||||
instances.convert_bbox(format='xyxy')
|
||||
instances.denormalize(w, h)
|
||||
if self.p and len(instances.segments):
|
||||
n = len(instances)
|
||||
_, w, _ = im.shape # height, width, channels
|
||||
im_new = np.zeros(im.shape, np.uint8)
|
||||
|
||||
# Calculate ioa first then select indexes randomly
|
||||
ins_flip = deepcopy(instances)
|
||||
ins_flip.fliplr(w)
|
||||
|
||||
ioa = bbox_ioa(ins_flip.bboxes, instances.bboxes) # intersection over area, (N, M)
|
||||
indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )
|
||||
n = len(indexes)
|
||||
for j in random.sample(list(indexes), k=round(self.p * n)):
|
||||
cls = np.concatenate((cls, cls[[j]]), axis=0)
|
||||
instances = Instances.concatenate((instances, ins_flip[[j]]), axis=0)
|
||||
cv2.drawContours(im_new, instances.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED)
|
||||
|
||||
result = cv2.flip(im, 1) # augment segments (flip left-right)
|
||||
i = cv2.flip(im_new, 1).astype(bool)
|
||||
im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
|
||||
|
||||
labels['img'] = im
|
||||
labels['cls'] = cls
|
||||
labels['instances'] = instances
|
||||
return labels
|
||||
|
||||
|
||||
class Albumentations:
|
||||
"""YOLOv8 Albumentations class (optional, only used if package is installed)"""
|
||||
|
||||
def __init__(self, p=1.0):
|
||||
"""Initialize the transform object for YOLO bbox formatted params."""
|
||||
self.p = p
|
||||
self.transform = None
|
||||
prefix = colorstr('albumentations: ')
|
||||
try:
|
||||
import albumentations as A
|
||||
|
||||
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
||||
|
||||
T = [
|
||||
A.Blur(p=0.01),
|
||||
A.MedianBlur(p=0.01),
|
||||
A.ToGray(p=0.01),
|
||||
A.CLAHE(p=0.01),
|
||||
A.RandomBrightnessContrast(p=0.0),
|
||||
A.RandomGamma(p=0.0),
|
||||
A.ImageCompression(quality_lower=75, p=0.0)] # transforms
|
||||
self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
|
||||
|
||||
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
|
||||
except ImportError: # package not installed, skip
|
||||
pass
|
||||
except Exception as e:
|
||||
LOGGER.info(f'{prefix}{e}')
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Generates object detections and returns a dictionary with detection results."""
|
||||
im = labels['img']
|
||||
cls = labels['cls']
|
||||
if len(cls):
|
||||
labels['instances'].convert_bbox('xywh')
|
||||
labels['instances'].normalize(*im.shape[:2][::-1])
|
||||
bboxes = labels['instances'].bboxes
|
||||
# TODO: add supports of segments and keypoints
|
||||
if self.transform and random.random() < self.p:
|
||||
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
|
||||
if len(new['class_labels']) > 0: # skip update if no bbox in new im
|
||||
labels['img'] = new['image']
|
||||
labels['cls'] = np.array(new['class_labels'])
|
||||
bboxes = np.array(new['bboxes'], dtype=np.float32)
|
||||
labels['instances'].update(bboxes=bboxes)
|
||||
return labels
|
||||
|
||||
|
||||
# TODO: technically this is not an augmentation, maybe we should put this to another files
|
||||
class Format:
|
||||
|
||||
def __init__(self,
|
||||
bbox_format='xywh',
|
||||
normalize=True,
|
||||
return_mask=False,
|
||||
return_keypoint=False,
|
||||
mask_ratio=4,
|
||||
mask_overlap=True,
|
||||
batch_idx=True):
|
||||
self.bbox_format = bbox_format
|
||||
self.normalize = normalize
|
||||
self.return_mask = return_mask # set False when training detection only
|
||||
self.return_keypoint = return_keypoint
|
||||
self.mask_ratio = mask_ratio
|
||||
self.mask_overlap = mask_overlap
|
||||
self.batch_idx = batch_idx # keep the batch indexes
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'."""
|
||||
img = labels.pop('img')
|
||||
h, w = img.shape[:2]
|
||||
cls = labels.pop('cls')
|
||||
instances = labels.pop('instances')
|
||||
instances.convert_bbox(format=self.bbox_format)
|
||||
instances.denormalize(w, h)
|
||||
nl = len(instances)
|
||||
|
||||
if self.return_mask:
|
||||
if nl:
|
||||
masks, instances, cls = self._format_segments(instances, cls, w, h)
|
||||
masks = torch.from_numpy(masks)
|
||||
else:
|
||||
masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
|
||||
img.shape[1] // self.mask_ratio)
|
||||
labels['masks'] = masks
|
||||
if self.normalize:
|
||||
instances.normalize(w, h)
|
||||
labels['img'] = self._format_img(img)
|
||||
labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl)
|
||||
labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
|
||||
if self.return_keypoint:
|
||||
labels['keypoints'] = torch.from_numpy(instances.keypoints)
|
||||
# Then we can use collate_fn
|
||||
if self.batch_idx:
|
||||
labels['batch_idx'] = torch.zeros(nl)
|
||||
return labels
|
||||
|
||||
def _format_img(self, img):
|
||||
"""Format the image for YOLOv5 from Numpy array to PyTorch tensor."""
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
|
||||
img = torch.from_numpy(img)
|
||||
return img
|
||||
|
||||
def _format_segments(self, instances, cls, w, h):
|
||||
"""convert polygon points to bitmap."""
|
||||
segments = instances.segments
|
||||
if self.mask_overlap:
|
||||
masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio)
|
||||
masks = masks[None] # (640, 640) -> (1, 640, 640)
|
||||
instances = instances[sorted_idx]
|
||||
cls = cls[sorted_idx]
|
||||
else:
|
||||
masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio)
|
||||
|
||||
return masks, instances, cls
|
||||
|
||||
|
||||
def v8_transforms(dataset, imgsz, hyp, stretch=False):
|
||||
"""Convert images to a size suitable for YOLOv8 training."""
|
||||
pre_transform = Compose([
|
||||
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
|
||||
CopyPaste(p=hyp.copy_paste),
|
||||
RandomPerspective(
|
||||
degrees=hyp.degrees,
|
||||
translate=hyp.translate,
|
||||
scale=hyp.scale,
|
||||
shear=hyp.shear,
|
||||
perspective=hyp.perspective,
|
||||
pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
|
||||
)])
|
||||
flip_idx = dataset.data.get('flip_idx', []) # for keypoints augmentation
|
||||
if dataset.use_keypoints:
|
||||
kpt_shape = dataset.data.get('kpt_shape', None)
|
||||
if len(flip_idx) == 0 and hyp.fliplr > 0.0:
|
||||
hyp.fliplr = 0.0
|
||||
LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
|
||||
elif flip_idx and (len(flip_idx) != kpt_shape[0]):
|
||||
raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}')
|
||||
|
||||
return Compose([
|
||||
pre_transform,
|
||||
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
|
||||
Albumentations(p=1.0),
|
||||
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
||||
RandomFlip(direction='vertical', p=hyp.flipud),
|
||||
RandomFlip(direction='horizontal', p=hyp.fliplr, flip_idx=flip_idx)]) # transforms
|
||||
|
||||
|
||||
# Classification augmentations -----------------------------------------------------------------------------------------
|
||||
def classify_transforms(size=224, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)): # IMAGENET_MEAN, IMAGENET_STD
|
||||
# Transforms to apply if albumentations not installed
|
||||
if not isinstance(size, int):
|
||||
raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
|
||||
if any(mean) or any(std):
|
||||
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(mean, std, inplace=True)])
|
||||
else:
|
||||
return T.Compose([CenterCrop(size), ToTensor()])
|
||||
|
||||
|
||||
def hsv2colorjitter(h, s, v):
|
||||
"""Map HSV (hue, saturation, value) jitter into ColorJitter values (brightness, contrast, saturation, hue)"""
|
||||
return v, v, s, h
|
||||
|
||||
|
||||
def classify_albumentations(
|
||||
augment=True,
|
||||
size=224,
|
||||
scale=(0.08, 1.0),
|
||||
hflip=0.5,
|
||||
vflip=0.0,
|
||||
hsv_h=0.015, # image HSV-Hue augmentation (fraction)
|
||||
hsv_s=0.7, # image HSV-Saturation augmentation (fraction)
|
||||
hsv_v=0.4, # image HSV-Value augmentation (fraction)
|
||||
mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
|
||||
std=(1.0, 1.0, 1.0), # IMAGENET_STD
|
||||
auto_aug=False,
|
||||
):
|
||||
"""YOLOv8 classification Albumentations (optional, only used if package is installed)."""
|
||||
prefix = colorstr('albumentations: ')
|
||||
try:
|
||||
import albumentations as A
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
|
||||
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
||||
if augment: # Resize and crop
|
||||
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
|
||||
if auto_aug:
|
||||
# TODO: implement AugMix, AutoAug & RandAug in albumentations
|
||||
LOGGER.info(f'{prefix}auto augmentations are currently not supported')
|
||||
else:
|
||||
if hflip > 0:
|
||||
T += [A.HorizontalFlip(p=hflip)]
|
||||
if vflip > 0:
|
||||
T += [A.VerticalFlip(p=vflip)]
|
||||
if any((hsv_h, hsv_s, hsv_v)):
|
||||
T += [A.ColorJitter(*hsv2colorjitter(hsv_h, hsv_s, hsv_v))] # brightness, contrast, saturation, hue
|
||||
else: # Use fixed crop for eval set (reproducibility)
|
||||
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
|
||||
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
|
||||
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
|
||||
return A.Compose(T)
|
||||
|
||||
except ImportError: # package not installed, skip
|
||||
pass
|
||||
except Exception as e:
|
||||
LOGGER.info(f'{prefix}{e}')
|
||||
|
||||
|
||||
class ClassifyLetterBox:
|
||||
"""YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])"""
|
||||
|
||||
def __init__(self, size=(640, 640), auto=False, stride=32):
|
||||
"""Resizes image and crops it to center with max dimensions 'h' and 'w'."""
|
||||
super().__init__()
|
||||
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||
self.auto = auto # pass max size integer, automatically solve for short side using stride
|
||||
self.stride = stride # used with auto
|
||||
|
||||
def __call__(self, im): # im = np.array HWC
|
||||
imh, imw = im.shape[:2]
|
||||
r = min(self.h / imh, self.w / imw) # ratio of new/old
|
||||
h, w = round(imh * r), round(imw * r) # resized image
|
||||
hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w
|
||||
top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
|
||||
im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
|
||||
im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
|
||||
return im_out
|
||||
|
||||
|
||||
class CenterCrop:
|
||||
"""YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])"""
|
||||
|
||||
def __init__(self, size=640):
|
||||
"""Converts an image from numpy array to PyTorch tensor."""
|
||||
super().__init__()
|
||||
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||
|
||||
def __call__(self, im): # im = np.array HWC
|
||||
imh, imw = im.shape[:2]
|
||||
m = min(imh, imw) # min dimension
|
||||
top, left = (imh - m) // 2, (imw - m) // 2
|
||||
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
|
||||
class ToTensor:
|
||||
"""YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])."""
|
||||
|
||||
def __init__(self, half=False):
|
||||
"""Initialize YOLOv8 ToTensor object with optional half-precision support."""
|
||||
super().__init__()
|
||||
self.half = half
|
||||
|
||||
def __call__(self, im): # im = np.array HWC in BGR order
|
||||
im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
|
||||
im = torch.from_numpy(im) # to torch
|
||||
im = im.half() if self.half else im.float() # uint8 to fp16/32
|
||||
im /= 255.0 # 0-255 to 0.0-1.0
|
||||
return im
|
@ -1,286 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import psutil
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT
|
||||
from .utils import HELP_URL, IMG_FORMATS
|
||||
|
||||
|
||||
class BaseDataset(Dataset):
|
||||
"""
|
||||
Base dataset class for loading and processing image data.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
imgsz (int, optional): Image size. Defaults to 640.
|
||||
cache (bool, optional): Cache images to RAM or disk during training. Defaults to False.
|
||||
augment (bool, optional): If True, data augmentation is applied. Defaults to True.
|
||||
hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None.
|
||||
prefix (str, optional): Prefix to print in log messages. Defaults to ''.
|
||||
rect (bool, optional): If True, rectangular training is used. Defaults to False.
|
||||
batch_size (int, optional): Size of batches. Defaults to None.
|
||||
stride (int, optional): Stride. Defaults to 32.
|
||||
pad (float, optional): Padding. Defaults to 0.0.
|
||||
single_cls (bool, optional): If True, single class training is used. Defaults to False.
|
||||
classes (list): List of included classes. Default is None.
|
||||
fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data).
|
||||
|
||||
Attributes:
|
||||
im_files (list): List of image file paths.
|
||||
labels (list): List of label data dictionaries.
|
||||
ni (int): Number of images in the dataset.
|
||||
ims (list): List of loaded images.
|
||||
npy_files (list): List of numpy file paths.
|
||||
transforms (callable): Image transformation function.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_path,
|
||||
imgsz=640,
|
||||
cache=False,
|
||||
augment=True,
|
||||
hyp=DEFAULT_CFG,
|
||||
prefix='',
|
||||
rect=False,
|
||||
batch_size=16,
|
||||
stride=32,
|
||||
pad=0.5,
|
||||
single_cls=False,
|
||||
classes=None,
|
||||
fraction=1.0):
|
||||
super().__init__()
|
||||
self.img_path = img_path
|
||||
self.imgsz = imgsz
|
||||
self.augment = augment
|
||||
self.single_cls = single_cls
|
||||
self.prefix = prefix
|
||||
self.fraction = fraction
|
||||
self.im_files = self.get_img_files(self.img_path)
|
||||
self.labels = self.get_labels()
|
||||
self.update_labels(include_class=classes) # single_cls and include_class
|
||||
self.ni = len(self.labels) # number of images
|
||||
self.rect = rect
|
||||
self.batch_size = batch_size
|
||||
self.stride = stride
|
||||
self.pad = pad
|
||||
if self.rect:
|
||||
assert self.batch_size is not None
|
||||
self.set_rectangle()
|
||||
|
||||
# Buffer thread for mosaic images
|
||||
self.buffer = [] # buffer size = batch size
|
||||
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
|
||||
|
||||
# Cache stuff
|
||||
if cache == 'ram' and not self.check_cache_ram():
|
||||
cache = False
|
||||
self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
|
||||
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
|
||||
if cache:
|
||||
self.cache_images(cache)
|
||||
|
||||
# Transforms
|
||||
self.transforms = self.build_transforms(hyp=hyp)
|
||||
|
||||
def get_img_files(self, img_path):
|
||||
"""Read image files."""
|
||||
try:
|
||||
f = [] # image files
|
||||
for p in img_path if isinstance(img_path, list) else [img_path]:
|
||||
p = Path(p) # os-agnostic
|
||||
if p.is_dir(): # dir
|
||||
f += glob.glob(str(p / '**' / '*.*'), recursive=True)
|
||||
# F = list(p.rglob('*.*')) # pathlib
|
||||
elif p.is_file(): # file
|
||||
with open(p) as t:
|
||||
t = t.read().strip().splitlines()
|
||||
parent = str(p.parent) + os.sep
|
||||
f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
|
||||
# F += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
|
||||
else:
|
||||
raise FileNotFoundError(f'{self.prefix}{p} does not exist')
|
||||
im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
|
||||
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
||||
assert im_files, f'{self.prefix}No images found'
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
|
||||
if self.fraction < 1:
|
||||
im_files = im_files[:round(len(im_files) * self.fraction)]
|
||||
return im_files
|
||||
|
||||
def update_labels(self, include_class: Optional[list]):
|
||||
"""include_class, filter labels to include only these classes (optional)."""
|
||||
include_class_array = np.array(include_class).reshape(1, -1)
|
||||
for i in range(len(self.labels)):
|
||||
if include_class is not None:
|
||||
cls = self.labels[i]['cls']
|
||||
bboxes = self.labels[i]['bboxes']
|
||||
segments = self.labels[i]['segments']
|
||||
keypoints = self.labels[i]['keypoints']
|
||||
j = (cls == include_class_array).any(1)
|
||||
self.labels[i]['cls'] = cls[j]
|
||||
self.labels[i]['bboxes'] = bboxes[j]
|
||||
if segments:
|
||||
self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
|
||||
if keypoints is not None:
|
||||
self.labels[i]['keypoints'] = keypoints[j]
|
||||
if self.single_cls:
|
||||
self.labels[i]['cls'][:, 0] = 0
|
||||
|
||||
def load_image(self, i):
|
||||
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
|
||||
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
|
||||
if im is None: # not cached in RAM
|
||||
if fn.exists(): # load npy
|
||||
im = np.load(fn)
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
if im is None:
|
||||
raise FileNotFoundError(f'Image Not Found {f}')
|
||||
h0, w0 = im.shape[:2] # orig hw
|
||||
r = self.imgsz / max(h0, w0) # ratio
|
||||
if r != 1: # if sizes are not equal
|
||||
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
|
||||
im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)),
|
||||
interpolation=interp)
|
||||
|
||||
# Add to buffer if training with augmentations
|
||||
if self.augment:
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
|
||||
self.buffer.append(i)
|
||||
if len(self.buffer) >= self.max_buffer_length:
|
||||
j = self.buffer.pop(0)
|
||||
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
|
||||
|
||||
return im, (h0, w0), im.shape[:2]
|
||||
|
||||
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
||||
|
||||
def cache_images(self, cache):
|
||||
"""Cache images to memory or disk."""
|
||||
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
||||
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(fcn, range(self.ni))
|
||||
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if cache == 'disk':
|
||||
b += self.npy_files[i].stat().st_size
|
||||
else: # 'ram'
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||
b += self.ims[i].nbytes
|
||||
pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})'
|
||||
pbar.close()
|
||||
|
||||
def cache_images_to_disk(self, i):
|
||||
"""Saves an image as an *.npy file for faster loading."""
|
||||
f = self.npy_files[i]
|
||||
if not f.exists():
|
||||
np.save(f.as_posix(), cv2.imread(self.im_files[i]))
|
||||
|
||||
def check_cache_ram(self, safety_margin=0.5):
|
||||
"""Check image caching requirements vs available memory."""
|
||||
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
||||
n = min(self.ni, 30) # extrapolate from 30 random images
|
||||
for _ in range(n):
|
||||
im = cv2.imread(random.choice(self.im_files)) # sample image
|
||||
ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
|
||||
b += im.nbytes * ratio ** 2
|
||||
mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM
|
||||
mem = psutil.virtual_memory()
|
||||
cache = mem_required < mem.available # to cache or not to cache, that is the question
|
||||
if not cache:
|
||||
LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
|
||||
f'with {int(safety_margin * 100)}% safety margin but only '
|
||||
f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
|
||||
f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
|
||||
return cache
|
||||
|
||||
def set_rectangle(self):
|
||||
"""Sets the shape of bounding boxes for YOLO detections as rectangles."""
|
||||
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
|
||||
nb = bi[-1] + 1 # number of batches
|
||||
|
||||
s = np.array([x.pop('shape') for x in self.labels]) # hw
|
||||
ar = s[:, 0] / s[:, 1] # aspect ratio
|
||||
irect = ar.argsort()
|
||||
self.im_files = [self.im_files[i] for i in irect]
|
||||
self.labels = [self.labels[i] for i in irect]
|
||||
ar = ar[irect]
|
||||
|
||||
# Set training image shapes
|
||||
shapes = [[1, 1]] * nb
|
||||
for i in range(nb):
|
||||
ari = ar[bi == i]
|
||||
mini, maxi = ari.min(), ari.max()
|
||||
if maxi < 1:
|
||||
shapes[i] = [maxi, 1]
|
||||
elif mini > 1:
|
||||
shapes[i] = [1, 1 / mini]
|
||||
|
||||
self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
|
||||
self.batch = bi # batch index of image
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Returns transformed label information for given index."""
|
||||
return self.transforms(self.get_image_and_label(index))
|
||||
|
||||
def get_image_and_label(self, index):
|
||||
"""Get and return label information from the dataset."""
|
||||
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
|
||||
label.pop('shape', None) # shape is for rect, remove it
|
||||
label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
|
||||
label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0],
|
||||
label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation
|
||||
if self.rect:
|
||||
label['rect_shape'] = self.batch_shapes[self.batch[index]]
|
||||
return self.update_labels_info(label)
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the length of the labels list for the dataset."""
|
||||
return len(self.labels)
|
||||
|
||||
def update_labels_info(self, label):
|
||||
"""custom your label format here."""
|
||||
return label
|
||||
|
||||
def build_transforms(self, hyp=None):
|
||||
"""Users can custom augmentations here
|
||||
like:
|
||||
if self.augment:
|
||||
# Training transforms
|
||||
return Compose([])
|
||||
else:
|
||||
# Val transforms
|
||||
return Compose([])
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_labels(self):
|
||||
"""Users can custom their own format here.
|
||||
Make sure your output is a list with each element like below:
|
||||
dict(
|
||||
im_file=im_file,
|
||||
shape=shape, # format: (height, width)
|
||||
cls=cls,
|
||||
bboxes=bboxes, # xywh
|
||||
segments=segments, # xy
|
||||
keypoints=keypoints, # xy
|
||||
normalized=True, # or False
|
||||
bbox_format="xyxy", # or xywh, ltwh
|
||||
)
|
||||
"""
|
||||
raise NotImplementedError
|
@ -1,170 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import dataloader, distributed
|
||||
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots,
|
||||
LoadStreams, LoadTensor, SourceTypes, autocast_list)
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.yolo.utils.checks import check_file
|
||||
|
||||
from ..utils import RANK, colorstr
|
||||
from .dataset import YOLODataset
|
||||
from .utils import PIN_MEMORY
|
||||
|
||||
|
||||
class InfiniteDataLoader(dataloader.DataLoader):
|
||||
"""Dataloader that reuses workers. Uses same syntax as vanilla DataLoader."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Dataloader that infinitely recycles workers, inherits from DataLoader."""
|
||||
super().__init__(*args, **kwargs)
|
||||
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
|
||||
self.iterator = super().__iter__()
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the length of the batch sampler's sampler."""
|
||||
return len(self.batch_sampler.sampler)
|
||||
|
||||
def __iter__(self):
|
||||
"""Creates a sampler that repeats indefinitely."""
|
||||
for _ in range(len(self)):
|
||||
yield next(self.iterator)
|
||||
|
||||
def reset(self):
|
||||
"""Reset iterator.
|
||||
This is useful when we want to modify settings of dataset while training.
|
||||
"""
|
||||
self.iterator = self._get_iterator()
|
||||
|
||||
|
||||
class _RepeatSampler:
|
||||
"""
|
||||
Sampler that repeats forever.
|
||||
|
||||
Args:
|
||||
sampler (Dataset.sampler): The sampler to repeat.
|
||||
"""
|
||||
|
||||
def __init__(self, sampler):
|
||||
"""Initializes an object that repeats a given sampler indefinitely."""
|
||||
self.sampler = sampler
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterates over the 'sampler' and yields its contents."""
|
||||
while True:
|
||||
yield from iter(self.sampler)
|
||||
|
||||
|
||||
def seed_worker(worker_id): # noqa
|
||||
"""Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
|
||||
worker_seed = torch.initial_seed() % 2 ** 32
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32):
|
||||
"""Build YOLO Dataset"""
|
||||
return YOLODataset(
|
||||
img_path=img_path,
|
||||
imgsz=cfg.imgsz,
|
||||
batch_size=batch,
|
||||
augment=mode == 'train', # augmentation
|
||||
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
||||
rect=cfg.rect or rect, # rectangular batches
|
||||
cache=cfg.cache or None,
|
||||
single_cls=cfg.single_cls or False,
|
||||
stride=int(stride),
|
||||
pad=0.0 if mode == 'train' else 0.5,
|
||||
prefix=colorstr(f'{mode}: '),
|
||||
use_segments=cfg.task == 'segment',
|
||||
use_keypoints=cfg.task == 'pose',
|
||||
classes=cfg.classes,
|
||||
data=data,
|
||||
fraction=cfg.fraction if mode == 'train' else 1.0)
|
||||
|
||||
|
||||
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
|
||||
"""Return an InfiniteDataLoader or DataLoader for training or validation set."""
|
||||
batch = min(batch, len(dataset))
|
||||
nd = torch.cuda.device_count() # number of CUDA devices
|
||||
nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers
|
||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(6148914691236517205 + RANK)
|
||||
return InfiniteDataLoader(dataset=dataset,
|
||||
batch_size=batch,
|
||||
shuffle=shuffle and sampler is None,
|
||||
num_workers=nw,
|
||||
sampler=sampler,
|
||||
pin_memory=PIN_MEMORY,
|
||||
collate_fn=getattr(dataset, 'collate_fn', None),
|
||||
worker_init_fn=seed_worker,
|
||||
generator=generator)
|
||||
|
||||
|
||||
def check_source(source):
|
||||
"""Check source type and return corresponding flag values."""
|
||||
webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
|
||||
if isinstance(source, (str, int, Path)): # int for local usb camera
|
||||
source = str(source)
|
||||
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
||||
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
|
||||
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
|
||||
screenshot = source.lower() == 'screen'
|
||||
if is_url and is_file:
|
||||
source = check_file(source) # download
|
||||
elif isinstance(source, tuple(LOADERS)):
|
||||
in_memory = True
|
||||
elif isinstance(source, (list, tuple)):
|
||||
source = autocast_list(source) # convert all list elements to PIL or np arrays
|
||||
from_img = True
|
||||
elif isinstance(source, (Image.Image, np.ndarray)):
|
||||
from_img = True
|
||||
elif isinstance(source, torch.Tensor):
|
||||
tensor = True
|
||||
else:
|
||||
raise TypeError('Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict')
|
||||
|
||||
return source, webcam, screenshot, from_img, in_memory, tensor
|
||||
|
||||
|
||||
def load_inference_source(source=None, imgsz=640, vid_stride=1):
|
||||
"""
|
||||
Loads an inference source for object detection and applies necessary transformations.
|
||||
|
||||
Args:
|
||||
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
|
||||
imgsz (int, optional): The size of the image for inference. Default is 640.
|
||||
vid_stride (int, optional): The frame interval for video sources. Default is 1.
|
||||
|
||||
Returns:
|
||||
dataset (Dataset): A dataset object for the specified input source.
|
||||
"""
|
||||
source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
|
||||
source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
|
||||
|
||||
# Dataloader
|
||||
if tensor:
|
||||
dataset = LoadTensor(source)
|
||||
elif in_memory:
|
||||
dataset = source
|
||||
elif webcam:
|
||||
dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride)
|
||||
elif screenshot:
|
||||
dataset = LoadScreenshots(source, imgsz=imgsz)
|
||||
elif from_img:
|
||||
dataset = LoadPilAndNumpy(source, imgsz=imgsz)
|
||||
else:
|
||||
dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride)
|
||||
|
||||
# Attach source types to the dataset
|
||||
setattr(dataset, 'source_type', source_type)
|
||||
|
||||
return dataset
|
@ -1,230 +0,0 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.yolo.utils.files import make_dirs
|
||||
|
||||
|
||||
def coco91_to_coco80_class():
|
||||
"""Converts 91-index COCO class IDs to 80-index COCO class IDs.
|
||||
|
||||
Returns:
|
||||
(list): A list of 91 class IDs where the index represents the 80-index class ID and the value is the
|
||||
corresponding 91-index class ID.
|
||||
|
||||
"""
|
||||
return [
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25, None,
|
||||
None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
|
||||
51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72,
|
||||
None, 73, 74, 75, 76, 77, 78, 79, None]
|
||||
|
||||
|
||||
def convert_coco(labels_dir='../coco/annotations/', use_segments=False, use_keypoints=False, cls91to80=True):
|
||||
"""Converts COCO dataset annotations to a format suitable for training YOLOv5 models.
|
||||
|
||||
Args:
|
||||
labels_dir (str, optional): Path to directory containing COCO dataset annotation files.
|
||||
use_segments (bool, optional): Whether to include segmentation masks in the output.
|
||||
use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
|
||||
cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the labels_dir path does not exist.
|
||||
|
||||
Example Usage:
|
||||
convert_coco(labels_dir='../coco/annotations/', use_segments=True, use_keypoints=True, cls91to80=True)
|
||||
|
||||
Output:
|
||||
Generates output files in the specified output directory.
|
||||
"""
|
||||
|
||||
save_dir = make_dirs('yolo_labels') # output directory
|
||||
coco80 = coco91_to_coco80_class()
|
||||
|
||||
# Import json
|
||||
for json_file in sorted(Path(labels_dir).resolve().glob('*.json')):
|
||||
fn = Path(save_dir) / 'labels' / json_file.stem.replace('instances_', '') # folder name
|
||||
fn.mkdir(parents=True, exist_ok=True)
|
||||
with open(json_file) as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Create image dict
|
||||
images = {f'{x["id"]:d}': x for x in data['images']}
|
||||
# Create image-annotations dict
|
||||
imgToAnns = defaultdict(list)
|
||||
for ann in data['annotations']:
|
||||
imgToAnns[ann['image_id']].append(ann)
|
||||
|
||||
# Write labels file
|
||||
for img_id, anns in tqdm(imgToAnns.items(), desc=f'Annotations {json_file}'):
|
||||
img = images[f'{img_id:d}']
|
||||
h, w, f = img['height'], img['width'], img['file_name']
|
||||
|
||||
bboxes = []
|
||||
segments = []
|
||||
keypoints = []
|
||||
for ann in anns:
|
||||
if ann['iscrowd']:
|
||||
continue
|
||||
# The COCO box format is [top left x, top left y, width, height]
|
||||
box = np.array(ann['bbox'], dtype=np.float64)
|
||||
box[:2] += box[2:] / 2 # xy top-left corner to center
|
||||
box[[0, 2]] /= w # normalize x
|
||||
box[[1, 3]] /= h # normalize y
|
||||
if box[2] <= 0 or box[3] <= 0: # if w <= 0 and h <= 0
|
||||
continue
|
||||
|
||||
cls = coco80[ann['category_id'] - 1] if cls91to80 else ann['category_id'] - 1 # class
|
||||
box = [cls] + box.tolist()
|
||||
if box not in bboxes:
|
||||
bboxes.append(box)
|
||||
if use_segments and ann.get('segmentation') is not None:
|
||||
if len(ann['segmentation']) == 0:
|
||||
segments.append([])
|
||||
continue
|
||||
if isinstance(ann['segmentation'], dict):
|
||||
ann['segmentation'] = rle2polygon(ann['segmentation'])
|
||||
if len(ann['segmentation']) > 1:
|
||||
s = merge_multi_segment(ann['segmentation'])
|
||||
s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist()
|
||||
else:
|
||||
s = [j for i in ann['segmentation'] for j in i] # all segments concatenated
|
||||
s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist()
|
||||
s = [cls] + s
|
||||
if s not in segments:
|
||||
segments.append(s)
|
||||
if use_keypoints and ann.get('keypoints') is not None:
|
||||
k = (np.array(ann['keypoints']).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist()
|
||||
k = box + k
|
||||
keypoints.append(k)
|
||||
|
||||
# Write
|
||||
with open((fn / f).with_suffix('.txt'), 'a') as file:
|
||||
for i in range(len(bboxes)):
|
||||
if use_keypoints:
|
||||
line = *(keypoints[i]), # cls, box, keypoints
|
||||
else:
|
||||
line = *(segments[i]
|
||||
if use_segments and len(segments[i]) > 0 else bboxes[i]), # cls, box or segments
|
||||
file.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||
|
||||
|
||||
def rle2polygon(segmentation):
|
||||
"""
|
||||
Convert Run-Length Encoding (RLE) mask to polygon coordinates.
|
||||
|
||||
Args:
|
||||
segmentation (dict, list): RLE mask representation of the object segmentation.
|
||||
|
||||
Returns:
|
||||
(list): A list of lists representing the polygon coordinates for each contour.
|
||||
|
||||
Note:
|
||||
Requires the 'pycocotools' package to be installed.
|
||||
"""
|
||||
check_requirements('pycocotools')
|
||||
from pycocotools import mask
|
||||
|
||||
m = mask.decode(segmentation)
|
||||
m[m > 0] = 255
|
||||
contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)
|
||||
polygons = []
|
||||
for contour in contours:
|
||||
epsilon = 0.001 * cv2.arcLength(contour, True)
|
||||
contour_approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
polygon = contour_approx.flatten().tolist()
|
||||
polygons.append(polygon)
|
||||
return polygons
|
||||
|
||||
|
||||
def min_index(arr1, arr2):
|
||||
"""
|
||||
Find a pair of indexes with the shortest distance between two arrays of 2D points.
|
||||
|
||||
Args:
|
||||
arr1 (np.array): A NumPy array of shape (N, 2) representing N 2D points.
|
||||
arr2 (np.array): A NumPy array of shape (M, 2) representing M 2D points.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing the indexes of the points with the shortest distance in arr1 and arr2 respectively.
|
||||
"""
|
||||
dis = ((arr1[:, None, :] - arr2[None, :, :]) ** 2).sum(-1)
|
||||
return np.unravel_index(np.argmin(dis, axis=None), dis.shape)
|
||||
|
||||
|
||||
def merge_multi_segment(segments):
|
||||
"""
|
||||
Merge multiple segments into one list by connecting the coordinates with the minimum distance between each segment.
|
||||
This function connects these coordinates with a thin line to merge all segments into one.
|
||||
|
||||
Args:
|
||||
segments (List[List]): Original segmentations in COCO's JSON file.
|
||||
Each element is a list of coordinates, like [segmentation1, segmentation2,...].
|
||||
|
||||
Returns:
|
||||
s (List[np.ndarray]): A list of connected segments represented as NumPy arrays.
|
||||
"""
|
||||
s = []
|
||||
segments = [np.array(i).reshape(-1, 2) for i in segments]
|
||||
idx_list = [[] for _ in range(len(segments))]
|
||||
|
||||
# record the indexes with min distance between each segment
|
||||
for i in range(1, len(segments)):
|
||||
idx1, idx2 = min_index(segments[i - 1], segments[i])
|
||||
idx_list[i - 1].append(idx1)
|
||||
idx_list[i].append(idx2)
|
||||
|
||||
# use two round to connect all the segments
|
||||
for k in range(2):
|
||||
# forward connection
|
||||
if k == 0:
|
||||
for i, idx in enumerate(idx_list):
|
||||
# middle segments have two indexes
|
||||
# reverse the index of middle segments
|
||||
if len(idx) == 2 and idx[0] > idx[1]:
|
||||
idx = idx[::-1]
|
||||
segments[i] = segments[i][::-1, :]
|
||||
|
||||
segments[i] = np.roll(segments[i], -idx[0], axis=0)
|
||||
segments[i] = np.concatenate([segments[i], segments[i][:1]])
|
||||
# deal with the first segment and the last one
|
||||
if i in [0, len(idx_list) - 1]:
|
||||
s.append(segments[i])
|
||||
else:
|
||||
idx = [0, idx[1] - idx[0]]
|
||||
s.append(segments[i][idx[0]:idx[1] + 1])
|
||||
|
||||
else:
|
||||
for i in range(len(idx_list) - 1, -1, -1):
|
||||
if i not in [0, len(idx_list) - 1]:
|
||||
idx = idx_list[i]
|
||||
nidx = abs(idx[1] - idx[0])
|
||||
s.append(segments[i][nidx:])
|
||||
return s
|
||||
|
||||
|
||||
def delete_dsstore(path='../datasets'):
|
||||
"""Delete Apple .DS_Store files in the specified directory and its subdirectories."""
|
||||
from pathlib import Path
|
||||
|
||||
files = list(Path(path).rglob('.DS_store'))
|
||||
print(files)
|
||||
for f in files:
|
||||
f.unlink()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
source = 'COCO'
|
||||
|
||||
if source == 'COCO':
|
||||
convert_coco(
|
||||
'../datasets/coco/annotations', # directory with *.json
|
||||
use_segments=False,
|
||||
use_keypoints=True,
|
||||
cls91to80=False)
|
@ -1,404 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
|
||||
|
||||
@dataclass
|
||||
class SourceTypes:
|
||||
webcam: bool = False
|
||||
screenshot: bool = False
|
||||
from_img: bool = False
|
||||
tensor: bool = False
|
||||
|
||||
|
||||
class LoadStreams:
|
||||
"""YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`."""
|
||||
|
||||
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1):
|
||||
"""Initialize instance variables and check for consistent input stream shapes."""
|
||||
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
|
||||
self.mode = 'stream'
|
||||
self.imgsz = imgsz
|
||||
self.vid_stride = vid_stride # video frame-rate stride
|
||||
sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
|
||||
n = len(sources)
|
||||
self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
|
||||
self.imgs, self.fps, self.frames, self.threads, self.shape = [[]] * n, [0] * n, [0] * n, [None] * n, [None] * n
|
||||
for i, s in enumerate(sources): # index, source
|
||||
# Start thread to read frames from video stream
|
||||
st = f'{i + 1}/{n}: {s}... '
|
||||
if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
|
||||
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
|
||||
s = get_best_youtube_url(s)
|
||||
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
||||
if s == 0 and (is_colab() or is_kaggle()):
|
||||
raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. "
|
||||
"Try running 'source=0' in a local environment.")
|
||||
cap = cv2.VideoCapture(s)
|
||||
if not cap.isOpened():
|
||||
raise ConnectionError(f'{st}Failed to open {s}')
|
||||
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
|
||||
self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
|
||||
self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
|
||||
|
||||
success, im = cap.read() # guarantee first frame
|
||||
if not success or im is None:
|
||||
raise ConnectionError(f'{st}Failed to read images from {s}')
|
||||
self.imgs[i].append(im)
|
||||
self.shape[i] = im.shape
|
||||
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
|
||||
LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)')
|
||||
self.threads[i].start()
|
||||
LOGGER.info('') # newline
|
||||
|
||||
# Check for common shapes
|
||||
self.bs = self.__len__()
|
||||
|
||||
def update(self, i, cap, stream):
|
||||
"""Read stream `i` frames in daemon thread."""
|
||||
n, f = 0, self.frames[i] # frame number, frame array
|
||||
while cap.isOpened() and n < f:
|
||||
# Only read a new frame if the buffer is empty
|
||||
if not self.imgs[i]:
|
||||
n += 1
|
||||
cap.grab() # .read() = .grab() followed by .retrieve()
|
||||
if n % self.vid_stride == 0:
|
||||
success, im = cap.retrieve()
|
||||
if success:
|
||||
self.imgs[i].append(im) # add image to buffer
|
||||
else:
|
||||
LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
|
||||
self.imgs[i].append(np.zeros(self.shape[i]))
|
||||
cap.open(stream) # re-open stream if signal was lost
|
||||
else:
|
||||
time.sleep(0.01) # wait until the buffer is empty
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterates through YOLO image feed and re-opens unresponsive streams."""
|
||||
self.count = -1
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""Returns source paths, transformed and original images for processing."""
|
||||
self.count += 1
|
||||
|
||||
# Wait until a frame is available in each buffer
|
||||
while not all(self.imgs):
|
||||
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
|
||||
cv2.destroyAllWindows()
|
||||
raise StopIteration
|
||||
time.sleep(1 / min(self.fps))
|
||||
|
||||
# Get and remove the next frame from imgs buffer
|
||||
return self.sources, [x.pop(0) for x in self.imgs], None, ''
|
||||
|
||||
def __len__(self):
|
||||
"""Return the length of the sources object."""
|
||||
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
|
||||
|
||||
|
||||
class LoadScreenshots:
|
||||
"""YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`."""
|
||||
|
||||
def __init__(self, source, imgsz=640):
|
||||
"""source = [screen_number left top width height] (pixels)."""
|
||||
check_requirements('mss')
|
||||
import mss # noqa
|
||||
|
||||
source, *params = source.split()
|
||||
self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
|
||||
if len(params) == 1:
|
||||
self.screen = int(params[0])
|
||||
elif len(params) == 4:
|
||||
left, top, width, height = (int(x) for x in params)
|
||||
elif len(params) == 5:
|
||||
self.screen, left, top, width, height = (int(x) for x in params)
|
||||
self.imgsz = imgsz
|
||||
self.mode = 'stream'
|
||||
self.frame = 0
|
||||
self.sct = mss.mss()
|
||||
self.bs = 1
|
||||
|
||||
# Parse monitor shape
|
||||
monitor = self.sct.monitors[self.screen]
|
||||
self.top = monitor['top'] if top is None else (monitor['top'] + top)
|
||||
self.left = monitor['left'] if left is None else (monitor['left'] + left)
|
||||
self.width = width or monitor['width']
|
||||
self.height = height or monitor['height']
|
||||
self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator of the object."""
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""mss screen capture: get raw pixels from the screen as np array."""
|
||||
im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
|
||||
s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
|
||||
|
||||
self.frame += 1
|
||||
return str(self.screen), im0, None, s # screen, img, original img, im0s, s
|
||||
|
||||
|
||||
class LoadImages:
|
||||
"""YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`."""
|
||||
|
||||
def __init__(self, path, imgsz=640, vid_stride=1):
|
||||
"""Initialize the Dataloader and raise FileNotFoundError if file not found."""
|
||||
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
|
||||
path = Path(path).read_text().rsplit()
|
||||
files = []
|
||||
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
||||
p = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
|
||||
if '*' in p:
|
||||
files.extend(sorted(glob.glob(p, recursive=True))) # glob
|
||||
elif os.path.isdir(p):
|
||||
files.extend(sorted(glob.glob(os.path.join(p, '*.*')))) # dir
|
||||
elif os.path.isfile(p):
|
||||
files.append(p) # files
|
||||
else:
|
||||
raise FileNotFoundError(f'{p} does not exist')
|
||||
|
||||
images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
|
||||
videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
|
||||
ni, nv = len(images), len(videos)
|
||||
|
||||
self.imgsz = imgsz
|
||||
self.files = images + videos
|
||||
self.nf = ni + nv # number of files
|
||||
self.video_flag = [False] * ni + [True] * nv
|
||||
self.mode = 'image'
|
||||
self.vid_stride = vid_stride # video frame-rate stride
|
||||
self.bs = 1
|
||||
if any(videos):
|
||||
self.orientation = None # rotation degrees
|
||||
self._new_video(videos[0]) # new video
|
||||
else:
|
||||
self.cap = None
|
||||
if self.nf == 0:
|
||||
raise FileNotFoundError(f'No images or videos found in {p}. '
|
||||
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}')
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator object for VideoStream or ImageFolder."""
|
||||
self.count = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""Return next image, path and metadata from dataset."""
|
||||
if self.count == self.nf:
|
||||
raise StopIteration
|
||||
path = self.files[self.count]
|
||||
|
||||
if self.video_flag[self.count]:
|
||||
# Read video
|
||||
self.mode = 'video'
|
||||
for _ in range(self.vid_stride):
|
||||
self.cap.grab()
|
||||
success, im0 = self.cap.retrieve()
|
||||
while not success:
|
||||
self.count += 1
|
||||
self.cap.release()
|
||||
if self.count == self.nf: # last video
|
||||
raise StopIteration
|
||||
path = self.files[self.count]
|
||||
self._new_video(path)
|
||||
success, im0 = self.cap.read()
|
||||
|
||||
self.frame += 1
|
||||
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
|
||||
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
|
||||
|
||||
else:
|
||||
# Read image
|
||||
self.count += 1
|
||||
im0 = cv2.imread(path) # BGR
|
||||
if im0 is None:
|
||||
raise FileNotFoundError(f'Image Not Found {path}')
|
||||
s = f'image {self.count}/{self.nf} {path}: '
|
||||
|
||||
return [path], [im0], self.cap, s
|
||||
|
||||
def _new_video(self, path):
|
||||
"""Create a new video capture object."""
|
||||
self.frame = 0
|
||||
self.cap = cv2.VideoCapture(path)
|
||||
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
|
||||
if hasattr(cv2, 'CAP_PROP_ORIENTATION_META'): # cv2<4.6.0 compatibility
|
||||
self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
|
||||
# Disable auto-orientation due to known issues in https://github.com/ultralytics/yolov5/issues/8493
|
||||
# self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)
|
||||
|
||||
def _cv2_rotate(self, im):
|
||||
"""Rotate a cv2 video manually."""
|
||||
if self.orientation == 0:
|
||||
return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
|
||||
elif self.orientation == 180:
|
||||
return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
elif self.orientation == 90:
|
||||
return cv2.rotate(im, cv2.ROTATE_180)
|
||||
return im
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of files in the object."""
|
||||
return self.nf # number of files
|
||||
|
||||
|
||||
class LoadPilAndNumpy:
|
||||
|
||||
def __init__(self, im0, imgsz=640):
|
||||
"""Initialize PIL and Numpy Dataloader."""
|
||||
if not isinstance(im0, list):
|
||||
im0 = [im0]
|
||||
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
|
||||
self.im0 = [self._single_check(im) for im in im0]
|
||||
self.imgsz = imgsz
|
||||
self.mode = 'image'
|
||||
# Generate fake paths
|
||||
self.bs = len(self.im0)
|
||||
|
||||
@staticmethod
|
||||
def _single_check(im):
|
||||
"""Validate and format an image to numpy array."""
|
||||
assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
|
||||
if isinstance(im, Image.Image):
|
||||
if im.mode != 'RGB':
|
||||
im = im.convert('RGB')
|
||||
im = np.asarray(im)[:, :, ::-1]
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
return im
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the length of the 'im0' attribute."""
|
||||
return len(self.im0)
|
||||
|
||||
def __next__(self):
|
||||
"""Returns batch paths, images, processed images, None, ''."""
|
||||
if self.count == 1: # loop only once as it's batch inference
|
||||
raise StopIteration
|
||||
self.count += 1
|
||||
return self.paths, self.im0, None, ''
|
||||
|
||||
def __iter__(self):
|
||||
"""Enables iteration for class LoadPilAndNumpy."""
|
||||
self.count = 0
|
||||
return self
|
||||
|
||||
|
||||
class LoadTensor:
|
||||
|
||||
def __init__(self, im0) -> None:
|
||||
self.im0 = self._single_check(im0)
|
||||
self.bs = self.im0.shape[0]
|
||||
self.mode = 'image'
|
||||
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
|
||||
|
||||
@staticmethod
|
||||
def _single_check(im, stride=32):
|
||||
"""Validate and format an image to torch.Tensor."""
|
||||
s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \
|
||||
f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.'
|
||||
if len(im.shape) != 4:
|
||||
if len(im.shape) == 3:
|
||||
LOGGER.warning(s)
|
||||
im = im.unsqueeze(0)
|
||||
else:
|
||||
raise ValueError(s)
|
||||
if im.shape[2] % stride or im.shape[3] % stride:
|
||||
raise ValueError(s)
|
||||
if im.max() > 1.0:
|
||||
LOGGER.warning(f'WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. '
|
||||
f'Dividing input by 255.')
|
||||
im = im.float() / 255.0
|
||||
|
||||
return im
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator object."""
|
||||
self.count = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""Return next item in the iterator."""
|
||||
if self.count == 1:
|
||||
raise StopIteration
|
||||
self.count += 1
|
||||
return self.paths, self.im0, None, ''
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the batch size."""
|
||||
return self.bs
|
||||
|
||||
|
||||
def autocast_list(source):
|
||||
"""
|
||||
Merges a list of source of different types into a list of numpy arrays or PIL images
|
||||
"""
|
||||
files = []
|
||||
for im in source:
|
||||
if isinstance(im, (str, Path)): # filename or uri
|
||||
files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im))
|
||||
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
|
||||
files.append(im)
|
||||
else:
|
||||
raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n'
|
||||
f'See https://docs.ultralytics.com/modes/predict for supported source types.')
|
||||
|
||||
return files
|
||||
|
||||
|
||||
LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
|
||||
|
||||
|
||||
def get_best_youtube_url(url, use_pafy=True):
|
||||
"""
|
||||
Retrieves the URL of the best quality MP4 video stream from a given YouTube video.
|
||||
|
||||
This function uses the pafy or yt_dlp library to extract the video info from YouTube. It then finds the highest
|
||||
quality MP4 format that has video codec but no audio codec, and returns the URL of this video stream.
|
||||
|
||||
Args:
|
||||
url (str): The URL of the YouTube video.
|
||||
use_pafy (bool): Use the pafy package, default=True, otherwise use yt_dlp package.
|
||||
|
||||
Returns:
|
||||
(str): The URL of the best quality MP4 video stream, or None if no suitable stream is found.
|
||||
"""
|
||||
if use_pafy:
|
||||
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
|
||||
import pafy # noqa
|
||||
return pafy.new(url).getbest(preftype='mp4').url
|
||||
else:
|
||||
check_requirements('yt-dlp')
|
||||
import yt_dlp
|
||||
with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
|
||||
info_dict = ydl.extract_info(url, download=False) # extract info
|
||||
for f in info_dict.get('formats', None):
|
||||
if f['vcodec'] != 'none' and f['acodec'] == 'none' and f['ext'] == 'mp4':
|
||||
return f.get('url', None)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
img = cv2.imread(str(ROOT / 'assets/bus.jpg'))
|
||||
dataset = LoadPilAndNumpy(im0=img)
|
||||
for d in dataset:
|
||||
print(d[0])
|
@ -1,274 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from itertools import repeat
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
|
||||
from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
|
||||
from .base import BaseDataset
|
||||
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label
|
||||
|
||||
|
||||
class YOLODataset(BaseDataset):
|
||||
"""
|
||||
Dataset class for loading object detection and/or segmentation labels in YOLO format.
|
||||
|
||||
Args:
|
||||
data (dict, optional): A dataset YAML dictionary. Defaults to None.
|
||||
use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
|
||||
use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.
|
||||
|
||||
Returns:
|
||||
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
|
||||
"""
|
||||
cache_version = '1.0.2' # dataset labels *.cache version, >= 1.0.0 for YOLOv8
|
||||
rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
|
||||
|
||||
def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs):
|
||||
self.use_segments = use_segments
|
||||
self.use_keypoints = use_keypoints
|
||||
self.data = data
|
||||
assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def cache_labels(self, path=Path('./labels.cache')):
|
||||
"""Cache dataset labels, check images and read shapes.
|
||||
Args:
|
||||
path (Path): path where to save the cache file (default: Path('./labels.cache')).
|
||||
Returns:
|
||||
(dict): labels.
|
||||
"""
|
||||
x = {'labels': []}
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
|
||||
total = len(self.im_files)
|
||||
nkpt, ndim = self.data.get('kpt_shape', (0, 0))
|
||||
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
|
||||
raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
|
||||
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(func=verify_image_label,
|
||||
iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
|
||||
repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
|
||||
repeat(ndim)))
|
||||
pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
|
||||
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
||||
nm += nm_f
|
||||
nf += nf_f
|
||||
ne += ne_f
|
||||
nc += nc_f
|
||||
if im_file:
|
||||
x['labels'].append(
|
||||
dict(
|
||||
im_file=im_file,
|
||||
shape=shape,
|
||||
cls=lb[:, 0:1], # n, 1
|
||||
bboxes=lb[:, 1:], # n, 4
|
||||
segments=segments,
|
||||
keypoints=keypoint,
|
||||
normalized=True,
|
||||
bbox_format='xywh'))
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
||||
pbar.close()
|
||||
|
||||
if msgs:
|
||||
LOGGER.info('\n'.join(msgs))
|
||||
if nf == 0:
|
||||
LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
|
||||
x['hash'] = get_hash(self.label_files + self.im_files)
|
||||
x['results'] = nf, nm, ne, nc, len(self.im_files)
|
||||
x['msgs'] = msgs # warnings
|
||||
x['version'] = self.cache_version # cache version
|
||||
if is_dir_writeable(path.parent):
|
||||
if path.exists():
|
||||
path.unlink() # remove *.cache file if exists
|
||||
np.save(str(path), x) # save cache for next time
|
||||
path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
|
||||
LOGGER.info(f'{self.prefix}New cache created: {path}')
|
||||
else:
|
||||
LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
|
||||
return x
|
||||
|
||||
def get_labels(self):
|
||||
"""Returns dictionary of labels for YOLO training."""
|
||||
self.label_files = img2label_paths(self.im_files)
|
||||
cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
|
||||
try:
|
||||
import gc
|
||||
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
|
||||
cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
|
||||
gc.enable()
|
||||
assert cache['version'] == self.cache_version # matches current version
|
||||
assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
|
||||
except (FileNotFoundError, AssertionError, AttributeError):
|
||||
cache, exists = self.cache_labels(cache_path), False # run cache ops
|
||||
|
||||
# Display cache
|
||||
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
|
||||
if exists and LOCAL_RANK in (-1, 0):
|
||||
d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
||||
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
|
||||
if cache['msgs']:
|
||||
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
|
||||
if nf == 0: # number of labels found
|
||||
raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}')
|
||||
|
||||
# Read cache
|
||||
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
|
||||
labels = cache['labels']
|
||||
self.im_files = [lb['im_file'] for lb in labels] # update im_files
|
||||
|
||||
# Check if the dataset is all boxes or all segments
|
||||
lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
|
||||
len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
|
||||
if len_segments and len_boxes != len_segments:
|
||||
LOGGER.warning(
|
||||
f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
|
||||
f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
|
||||
'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
|
||||
for lb in labels:
|
||||
lb['segments'] = []
|
||||
if len_cls == 0:
|
||||
raise ValueError(f'All labels empty in {cache_path}, can not start training without labels. {HELP_URL}')
|
||||
return labels
|
||||
|
||||
# TODO: use hyp config to set all these augmentations
|
||||
def build_transforms(self, hyp=None):
|
||||
"""Builds and appends transforms to the list."""
|
||||
if self.augment:
|
||||
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
|
||||
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
|
||||
transforms = v8_transforms(self, self.imgsz, hyp)
|
||||
else:
|
||||
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
|
||||
transforms.append(
|
||||
Format(bbox_format='xywh',
|
||||
normalize=True,
|
||||
return_mask=self.use_segments,
|
||||
return_keypoint=self.use_keypoints,
|
||||
batch_idx=True,
|
||||
mask_ratio=hyp.mask_ratio,
|
||||
mask_overlap=hyp.overlap_mask))
|
||||
return transforms
|
||||
|
||||
def close_mosaic(self, hyp):
|
||||
"""Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
|
||||
hyp.mosaic = 0.0 # set mosaic ratio=0.0
|
||||
hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
|
||||
hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
|
||||
self.transforms = self.build_transforms(hyp)
|
||||
|
||||
def update_labels_info(self, label):
|
||||
"""custom your label format here."""
|
||||
# NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
|
||||
# we can make it also support classification and semantic segmentation by add or remove some dict keys there.
|
||||
bboxes = label.pop('bboxes')
|
||||
segments = label.pop('segments')
|
||||
keypoints = label.pop('keypoints', None)
|
||||
bbox_format = label.pop('bbox_format')
|
||||
normalized = label.pop('normalized')
|
||||
label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
|
||||
return label
|
||||
|
||||
@staticmethod
|
||||
def collate_fn(batch):
|
||||
"""Collates data samples into batches."""
|
||||
new_batch = {}
|
||||
keys = batch[0].keys()
|
||||
values = list(zip(*[list(b.values()) for b in batch]))
|
||||
for i, k in enumerate(keys):
|
||||
value = values[i]
|
||||
if k == 'img':
|
||||
value = torch.stack(value, 0)
|
||||
if k in ['masks', 'keypoints', 'bboxes', 'cls']:
|
||||
value = torch.cat(value, 0)
|
||||
new_batch[k] = value
|
||||
new_batch['batch_idx'] = list(new_batch['batch_idx'])
|
||||
for i in range(len(new_batch['batch_idx'])):
|
||||
new_batch['batch_idx'][i] += i # add target image index for build_targets()
|
||||
new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
|
||||
return new_batch
|
||||
|
||||
|
||||
# Classification dataloaders -------------------------------------------------------------------------------------------
|
||||
class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||
"""
|
||||
YOLO Classification Dataset.
|
||||
|
||||
Args:
|
||||
root (str): Dataset path.
|
||||
|
||||
Attributes:
|
||||
cache_ram (bool): True if images should be cached in RAM, False otherwise.
|
||||
cache_disk (bool): True if images should be cached on disk, False otherwise.
|
||||
samples (list): List of samples containing file, index, npy, and im.
|
||||
torch_transforms (callable): torchvision transforms applied to the dataset.
|
||||
album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
|
||||
"""
|
||||
|
||||
def __init__(self, root, args, augment=False, cache=False):
|
||||
"""
|
||||
Initialize YOLO object with root, image size, augmentations, and cache settings.
|
||||
|
||||
Args:
|
||||
root (str): Dataset path.
|
||||
args (Namespace): Argument parser containing dataset related settings.
|
||||
augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False.
|
||||
cache (bool | str | optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
|
||||
"""
|
||||
super().__init__(root=root)
|
||||
if augment and args.fraction < 1.0: # reduce training fraction
|
||||
self.samples = self.samples[:round(len(self.samples) * args.fraction)]
|
||||
self.cache_ram = cache is True or cache == 'ram'
|
||||
self.cache_disk = cache == 'disk'
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
|
||||
self.torch_transforms = classify_transforms(args.imgsz)
|
||||
self.album_transforms = classify_albumentations(
|
||||
augment=augment,
|
||||
size=args.imgsz,
|
||||
scale=(1.0 - args.scale, 1.0), # (0.08, 1.0)
|
||||
hflip=args.fliplr,
|
||||
vflip=args.flipud,
|
||||
hsv_h=args.hsv_h, # HSV-Hue augmentation (fraction)
|
||||
hsv_s=args.hsv_s, # HSV-Saturation augmentation (fraction)
|
||||
hsv_v=args.hsv_v, # HSV-Value augmentation (fraction)
|
||||
mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
|
||||
std=(1.0, 1.0, 1.0), # IMAGENET_STD
|
||||
auto_aug=False) if augment else None
|
||||
|
||||
def __getitem__(self, i):
|
||||
"""Returns subset of data and targets corresponding to given indices."""
|
||||
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
||||
if self.cache_ram and im is None:
|
||||
im = self.samples[i][3] = cv2.imread(f)
|
||||
elif self.cache_disk:
|
||||
if not fn.exists(): # load npy
|
||||
np.save(fn.as_posix(), cv2.imread(f))
|
||||
im = np.load(fn)
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
if self.album_transforms:
|
||||
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
|
||||
else:
|
||||
sample = self.torch_transforms(im)
|
||||
return {'img': sample, 'cls': j}
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.samples)
|
||||
|
||||
|
||||
# TODO: support semantic segmentation
|
||||
class SemanticDataset(BaseDataset):
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize a SemanticDataset object."""
|
||||
super().__init__()
|
@ -1,53 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import collections
|
||||
from copy import deepcopy
|
||||
|
||||
from .augment import LetterBox
|
||||
|
||||
|
||||
class MixAndRectDataset:
|
||||
"""
|
||||
A dataset class that applies mosaic and mixup transformations as well as rectangular training.
|
||||
|
||||
Attributes:
|
||||
dataset: The base dataset.
|
||||
imgsz: The size of the images in the dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
"""
|
||||
Args:
|
||||
dataset (BaseDataset): The base dataset to apply transformations to.
|
||||
"""
|
||||
self.dataset = dataset
|
||||
self.imgsz = dataset.imgsz
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of items in the dataset."""
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Applies mosaic, mixup and rectangular training transformations to an item in the dataset.
|
||||
|
||||
Args:
|
||||
index (int): Index of the item in the dataset.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the transformed item data.
|
||||
"""
|
||||
labels = deepcopy(self.dataset[index])
|
||||
for transform in self.dataset.transforms.tolist():
|
||||
# Mosaic and mixup
|
||||
if hasattr(transform, 'get_indexes'):
|
||||
indexes = transform.get_indexes(self.dataset)
|
||||
if not isinstance(indexes, collections.abc.Sequence):
|
||||
indexes = [indexes]
|
||||
labels['mix_labels'] = [deepcopy(self.dataset[index]) for index in indexes]
|
||||
if self.dataset.rect and isinstance(transform, LetterBox):
|
||||
transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
|
||||
labels = transform(labels)
|
||||
if 'mix_labels' in labels:
|
||||
labels.pop('mix_labels')
|
||||
return labels
|
@ -1,18 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
# Download latest models from https://github.com/ultralytics/assets/releases
|
||||
# Example usage: bash ultralytics/yolo/data/scripts/download_weights.sh
|
||||
# parent
|
||||
# └── weights
|
||||
# ├── yolov8n.pt ← downloads here
|
||||
# ├── yolov8s.pt
|
||||
# └── ...
|
||||
|
||||
python - <<EOF
|
||||
from ultralytics.yolo.utils.downloads import attempt_download_asset
|
||||
|
||||
assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '-cls', '-seg')]
|
||||
for x in assets:
|
||||
attempt_download_asset(f'weights/{x}')
|
||||
|
||||
EOF
|
@ -1,60 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
# Download COCO 2017 dataset http://cocodataset.org
|
||||
# Example usage: bash data/scripts/get_coco.sh
|
||||
# parent
|
||||
# ├── ultralytics
|
||||
# └── datasets
|
||||
# └── coco ← downloads here
|
||||
|
||||
# Arguments (optional) Usage: bash data/scripts/get_coco.sh --train --val --test --segments
|
||||
if [ "$#" -gt 0 ]; then
|
||||
for opt in "$@"; do
|
||||
case "${opt}" in
|
||||
--train) train=true ;;
|
||||
--val) val=true ;;
|
||||
--test) test=true ;;
|
||||
--segments) segments=true ;;
|
||||
--sama) sama=true ;;
|
||||
esac
|
||||
done
|
||||
else
|
||||
train=true
|
||||
val=true
|
||||
test=false
|
||||
segments=false
|
||||
sama=false
|
||||
fi
|
||||
|
||||
# Download/unzip labels
|
||||
d='../datasets' # unzip directory
|
||||
url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
|
||||
if [ "$segments" == "true" ]; then
|
||||
f='coco2017labels-segments.zip' # 169 MB
|
||||
elif [ "$sama" == "true" ]; then
|
||||
f='coco2017labels-segments-sama.zip' # 199 MB https://www.sama.com/sama-coco-dataset/
|
||||
else
|
||||
f='coco2017labels.zip' # 46 MB
|
||||
fi
|
||||
echo 'Downloading' $url$f ' ...'
|
||||
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
|
||||
|
||||
# Download/unzip images
|
||||
d='../datasets/coco/images' # unzip directory
|
||||
url=http://images.cocodataset.org/zips/
|
||||
if [ "$train" == "true" ]; then
|
||||
f='train2017.zip' # 19G, 118k images
|
||||
echo 'Downloading' $url$f '...'
|
||||
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
|
||||
fi
|
||||
if [ "$val" == "true" ]; then
|
||||
f='val2017.zip' # 1G, 5k images
|
||||
echo 'Downloading' $url$f '...'
|
||||
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
|
||||
fi
|
||||
if [ "$test" == "true" ]; then
|
||||
f='test2017.zip' # 7G, 41k images (optional)
|
||||
echo 'Downloading' $url$f '...'
|
||||
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
|
||||
fi
|
||||
wait # finish background tasks
|
@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
# Download COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017)
|
||||
# Example usage: bash data/scripts/get_coco128.sh
|
||||
# parent
|
||||
# ├── ultralytics
|
||||
# └── datasets
|
||||
# └── coco128 ← downloads here
|
||||
|
||||
# Download/unzip images and labels
|
||||
d='../datasets' # unzip directory
|
||||
url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
|
||||
f='coco128.zip' # or 'coco128-segments.zip', 68 MB
|
||||
echo 'Downloading' $url$f ' ...'
|
||||
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
|
||||
|
||||
wait # finish background tasks
|
@ -1,51 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
# Download ILSVRC2012 ImageNet dataset https://image-net.org
|
||||
# Example usage: bash data/scripts/get_imagenet.sh
|
||||
# parent
|
||||
# ├── ultralytics
|
||||
# └── datasets
|
||||
# └── imagenet ← downloads here
|
||||
|
||||
# Arguments (optional) Usage: bash data/scripts/get_imagenet.sh --train --val
|
||||
if [ "$#" -gt 0 ]; then
|
||||
for opt in "$@"; do
|
||||
case "${opt}" in
|
||||
--train) train=true ;;
|
||||
--val) val=true ;;
|
||||
esac
|
||||
done
|
||||
else
|
||||
train=true
|
||||
val=true
|
||||
fi
|
||||
|
||||
# Make dir
|
||||
d='../datasets/imagenet' # unzip directory
|
||||
mkdir -p $d && cd $d
|
||||
|
||||
# Download/unzip train
|
||||
if [ "$train" == "true" ]; then
|
||||
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar # download 138G, 1281167 images
|
||||
mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
|
||||
tar -xf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
|
||||
find . -name "*.tar" | while read NAME; do
|
||||
mkdir -p "${NAME%.tar}"
|
||||
tar -xf "${NAME}" -C "${NAME%.tar}"
|
||||
rm -f "${NAME}"
|
||||
done
|
||||
cd ..
|
||||
fi
|
||||
|
||||
# Download/unzip val
|
||||
if [ "$val" == "true" ]; then
|
||||
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar # download 6.3G, 50000 images
|
||||
mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xf ILSVRC2012_img_val.tar
|
||||
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash # move into subdirs
|
||||
fi
|
||||
|
||||
# Delete corrupted image (optional: PNG under JPEG name that may cause dataloaders to fail)
|
||||
# rm train/n04266014/n04266014_10835.JPEG
|
||||
|
||||
# TFRecords (optional)
|
||||
# wget https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_lsvrc_2015_synsets.txt
|
@ -1,557 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import time
|
||||
import zipfile
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
from tarfile import is_tarfile
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import ExifTags, Image, ImageOps
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.nn.autobackend import check_class_names
|
||||
from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, clean_url, colorstr, emojis,
|
||||
yaml_load)
|
||||
from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
|
||||
from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file
|
||||
from ultralytics.yolo.utils.ops import segments2boxes
|
||||
|
||||
HELP_URL = 'See https://docs.ultralytics.com/yolov5/tutorials/train_custom_data'
|
||||
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes
|
||||
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm' # video suffixes
|
||||
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
|
||||
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
|
||||
IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
|
||||
|
||||
# Get orientation exif tag
|
||||
for orientation in ExifTags.TAGS.keys():
|
||||
if ExifTags.TAGS[orientation] == 'Orientation':
|
||||
break
|
||||
|
||||
|
||||
def img2label_paths(img_paths):
|
||||
"""Define label paths as a function of image paths."""
|
||||
sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
|
||||
return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
|
||||
|
||||
|
||||
def get_hash(paths):
|
||||
"""Returns a single hash value of a list of paths (files or dirs)."""
|
||||
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
|
||||
h = hashlib.sha256(str(size).encode()) # hash sizes
|
||||
h.update(''.join(paths).encode()) # hash paths
|
||||
return h.hexdigest() # return hash
|
||||
|
||||
|
||||
def exif_size(img):
|
||||
"""Returns exif-corrected PIL size."""
|
||||
s = img.size # (width, height)
|
||||
with contextlib.suppress(Exception):
|
||||
rotation = dict(img._getexif().items())[orientation]
|
||||
if rotation in [6, 8]: # rotation 270 or 90
|
||||
s = (s[1], s[0])
|
||||
return s
|
||||
|
||||
|
||||
def verify_image_label(args):
|
||||
"""Verify one image-label pair."""
|
||||
im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
|
||||
# Number (missing, found, empty, corrupt), message, segments, keypoints
|
||||
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
|
||||
try:
|
||||
# Verify images
|
||||
im = Image.open(im_file)
|
||||
im.verify() # PIL verify
|
||||
shape = exif_size(im) # image size
|
||||
shape = (shape[1], shape[0]) # hw
|
||||
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
|
||||
assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
|
||||
if im.format.lower() in ('jpg', 'jpeg'):
|
||||
with open(im_file, 'rb') as f:
|
||||
f.seek(-2, 2)
|
||||
if f.read() != b'\xff\xd9': # corrupt JPEG
|
||||
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
|
||||
msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
|
||||
|
||||
# Verify labels
|
||||
if os.path.isfile(lb_file):
|
||||
nf = 1 # label found
|
||||
with open(lb_file) as f:
|
||||
lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
|
||||
if any(len(x) > 6 for x in lb) and (not keypoint): # is segment
|
||||
classes = np.array([x[0] for x in lb], dtype=np.float32)
|
||||
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
|
||||
lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
|
||||
lb = np.array(lb, dtype=np.float32)
|
||||
nl = len(lb)
|
||||
if nl:
|
||||
if keypoint:
|
||||
assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each'
|
||||
assert (lb[:, 5::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
|
||||
assert (lb[:, 6::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
|
||||
else:
|
||||
assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
|
||||
assert (lb[:, 1:] <= 1).all(), \
|
||||
f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
|
||||
assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
|
||||
# All labels
|
||||
max_cls = int(lb[:, 0].max()) # max label count
|
||||
assert max_cls <= num_cls, \
|
||||
f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
|
||||
f'Possible class labels are 0-{num_cls - 1}'
|
||||
_, i = np.unique(lb, axis=0, return_index=True)
|
||||
if len(i) < nl: # duplicate row check
|
||||
lb = lb[i] # remove duplicates
|
||||
if segments:
|
||||
segments = [segments[x] for x in i]
|
||||
msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
|
||||
else:
|
||||
ne = 1 # label empty
|
||||
lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros(
|
||||
(0, 5), dtype=np.float32)
|
||||
else:
|
||||
nm = 1 # label missing
|
||||
lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
|
||||
if keypoint:
|
||||
keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
|
||||
if ndim == 2:
|
||||
kpt_mask = np.ones(keypoints.shape[:2], dtype=np.float32)
|
||||
kpt_mask = np.where(keypoints[..., 0] < 0, 0.0, kpt_mask)
|
||||
kpt_mask = np.where(keypoints[..., 1] < 0, 0.0, kpt_mask)
|
||||
keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1) # (nl, nkpt, 3)
|
||||
lb = lb[:, :5]
|
||||
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
|
||||
except Exception as e:
|
||||
nc = 1
|
||||
msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
|
||||
return [None, None, None, None, None, nm, nf, ne, nc, msg]
|
||||
|
||||
|
||||
def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1):
|
||||
"""
|
||||
Args:
|
||||
imgsz (tuple): The image size.
|
||||
polygons (list[np.ndarray]): [N, M], N is the number of polygons, M is the number of points(Be divided by 2).
|
||||
color (int): color
|
||||
downsample_ratio (int): downsample ratio
|
||||
"""
|
||||
mask = np.zeros(imgsz, dtype=np.uint8)
|
||||
polygons = np.asarray(polygons)
|
||||
polygons = polygons.astype(np.int32)
|
||||
shape = polygons.shape
|
||||
polygons = polygons.reshape(shape[0], -1, 2)
|
||||
cv2.fillPoly(mask, polygons, color=color)
|
||||
nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
|
||||
# NOTE: fillPoly firstly then resize is trying the keep the same way
|
||||
# of loss calculation when mask-ratio=1.
|
||||
mask = cv2.resize(mask, (nw, nh))
|
||||
return mask
|
||||
|
||||
|
||||
def polygons2masks(imgsz, polygons, color, downsample_ratio=1):
|
||||
"""
|
||||
Args:
|
||||
imgsz (tuple): The image size.
|
||||
polygons (list[np.ndarray]): each polygon is [N, M], N is number of polygons, M is number of points (M % 2 = 0)
|
||||
color (int): color
|
||||
downsample_ratio (int): downsample ratio
|
||||
"""
|
||||
masks = []
|
||||
for si in range(len(polygons)):
|
||||
mask = polygon2mask(imgsz, [polygons[si].reshape(-1)], color, downsample_ratio)
|
||||
masks.append(mask)
|
||||
return np.array(masks)
|
||||
|
||||
|
||||
def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
|
||||
"""Return a (640, 640) overlap mask."""
|
||||
masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
|
||||
dtype=np.int32 if len(segments) > 255 else np.uint8)
|
||||
areas = []
|
||||
ms = []
|
||||
for si in range(len(segments)):
|
||||
mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1)
|
||||
ms.append(mask)
|
||||
areas.append(mask.sum())
|
||||
areas = np.asarray(areas)
|
||||
index = np.argsort(-areas)
|
||||
ms = np.array(ms)[index]
|
||||
for i in range(len(segments)):
|
||||
mask = ms[i] * (i + 1)
|
||||
masks = masks + mask
|
||||
masks = np.clip(masks, a_min=0, a_max=i + 1)
|
||||
return masks, index
|
||||
|
||||
|
||||
def check_det_dataset(dataset, autodownload=True):
|
||||
"""Download, check and/or unzip dataset if not found locally."""
|
||||
data = check_file(dataset)
|
||||
|
||||
# Download (optional)
|
||||
extract_dir = ''
|
||||
if isinstance(data, (str, Path)) and (zipfile.is_zipfile(data) or is_tarfile(data)):
|
||||
new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False)
|
||||
data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))
|
||||
extract_dir, autodownload = data.parent, False
|
||||
|
||||
# Read yaml (optional)
|
||||
if isinstance(data, (str, Path)):
|
||||
data = yaml_load(data, append_filename=True) # dictionary
|
||||
|
||||
# Checks
|
||||
for k in 'train', 'val':
|
||||
if k not in data:
|
||||
raise SyntaxError(
|
||||
emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs."))
|
||||
if 'names' not in data and 'nc' not in data:
|
||||
raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))
|
||||
if 'names' in data and 'nc' in data and len(data['names']) != data['nc']:
|
||||
raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
|
||||
if 'names' not in data:
|
||||
data['names'] = [f'class_{i}' for i in range(data['nc'])]
|
||||
else:
|
||||
data['nc'] = len(data['names'])
|
||||
|
||||
data['names'] = check_class_names(data['names'])
|
||||
|
||||
# Resolve paths
|
||||
path = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent) # dataset root
|
||||
|
||||
if not path.is_absolute():
|
||||
path = (DATASETS_DIR / path).resolve()
|
||||
data['path'] = path # download scripts
|
||||
for k in 'train', 'val', 'test':
|
||||
if data.get(k): # prepend path
|
||||
if isinstance(data[k], str):
|
||||
x = (path / data[k]).resolve()
|
||||
if not x.exists() and data[k].startswith('../'):
|
||||
x = (path / data[k][3:]).resolve()
|
||||
data[k] = str(x)
|
||||
else:
|
||||
data[k] = [str((path / x).resolve()) for x in data[k]]
|
||||
|
||||
# Parse yaml
|
||||
train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
|
||||
if val:
|
||||
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
|
||||
if not all(x.exists() for x in val):
|
||||
name = clean_url(dataset) # dataset name with URL auth stripped
|
||||
m = f"\nDataset '{name}' images not found ⚠️, missing paths %s" % [str(x) for x in val if not x.exists()]
|
||||
if s and autodownload:
|
||||
LOGGER.warning(m)
|
||||
else:
|
||||
m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'"
|
||||
raise FileNotFoundError(m)
|
||||
t = time.time()
|
||||
if s.startswith('http') and s.endswith('.zip'): # URL
|
||||
safe_download(url=s, dir=DATASETS_DIR, delete=True)
|
||||
r = None # success
|
||||
elif s.startswith('bash '): # bash script
|
||||
LOGGER.info(f'Running {s} ...')
|
||||
r = os.system(s)
|
||||
else: # python script
|
||||
r = exec(s, {'yaml': data}) # return None
|
||||
dt = f'({round(time.time() - t, 1)}s)'
|
||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
|
||||
LOGGER.info(f'Dataset download {s}\n')
|
||||
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
|
||||
|
||||
return data # dictionary
|
||||
|
||||
|
||||
def check_cls_dataset(dataset: str, split=''):
|
||||
"""
|
||||
Checks a classification dataset such as Imagenet.
|
||||
|
||||
This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
|
||||
If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
|
||||
|
||||
Args:
|
||||
dataset (str): The name of the dataset.
|
||||
split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the following keys:
|
||||
- 'train' (Path): The directory path containing the training set of the dataset.
|
||||
- 'val' (Path): The directory path containing the validation set of the dataset.
|
||||
- 'test' (Path): The directory path containing the test set of the dataset.
|
||||
- 'nc' (int): The number of classes in the dataset.
|
||||
- 'names' (dict): A dictionary of class names in the dataset.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the specified dataset is not found and cannot be downloaded.
|
||||
"""
|
||||
|
||||
dataset = Path(dataset)
|
||||
data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
|
||||
if not data_dir.is_dir():
|
||||
LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
|
||||
t = time.time()
|
||||
if str(dataset) == 'imagenet':
|
||||
subprocess.run(f"bash {ROOT / 'yolo/data/scripts/get_imagenet.sh'}", shell=True, check=True)
|
||||
else:
|
||||
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
|
||||
download(url, dir=data_dir.parent)
|
||||
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
|
||||
LOGGER.info(s)
|
||||
train_set = data_dir / 'train'
|
||||
val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val
|
||||
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
|
||||
if split == 'val' and not val_set:
|
||||
LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
|
||||
elif split == 'test' and not test_set:
|
||||
LOGGER.info("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
|
||||
|
||||
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
|
||||
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
|
||||
names = dict(enumerate(sorted(names)))
|
||||
return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}
|
||||
|
||||
|
||||
class HUBDatasetStats():
|
||||
"""
|
||||
A class for generating HUB dataset JSON and `-hub` dataset directory.
|
||||
|
||||
Args:
|
||||
path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco128.yaml'.
|
||||
task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
|
||||
autodownload (bool): Attempt to download dataset if not found locally. Default is False.
|
||||
|
||||
Usage
|
||||
from ultralytics.yolo.data.utils import HUBDatasetStats
|
||||
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8.zip', task='detect') # detect dataset
|
||||
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-seg.zip', task='segment') # segment dataset
|
||||
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-pose.zip', task='pose') # pose dataset
|
||||
stats.get_json(save=False)
|
||||
stats.process_images()
|
||||
"""
|
||||
|
||||
def __init__(self, path='coco128.yaml', task='detect', autodownload=False):
|
||||
"""Initialize class."""
|
||||
LOGGER.info(f'Starting HUB dataset checks for {path}....')
|
||||
zipped, data_dir, yaml_path = self._unzip(Path(path))
|
||||
try:
|
||||
# data = yaml_load(check_yaml(yaml_path)) # data dict
|
||||
data = check_det_dataset(yaml_path, autodownload) # data dict
|
||||
if zipped:
|
||||
data['path'] = data_dir
|
||||
except Exception as e:
|
||||
raise Exception('error/HUB/dataset_stats/yaml_load') from e
|
||||
|
||||
self.hub_dir = Path(str(data['path']) + '-hub')
|
||||
self.im_dir = self.hub_dir / 'images'
|
||||
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
|
||||
self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary
|
||||
self.data = data
|
||||
self.task = task # detect, segment, pose, classify
|
||||
|
||||
@staticmethod
|
||||
def _find_yaml(dir):
|
||||
"""Return data.yaml file."""
|
||||
files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
|
||||
assert files, f'No *.yaml file found in {dir}'
|
||||
if len(files) > 1:
|
||||
files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
|
||||
assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
|
||||
assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
|
||||
return files[0]
|
||||
|
||||
def _unzip(self, path):
|
||||
"""Unzip data.zip."""
|
||||
if not str(path).endswith('.zip'): # path is data.yaml
|
||||
return False, None, path
|
||||
unzip_dir = unzip_file(path, path=path.parent)
|
||||
assert unzip_dir.is_dir(), f'Error unzipping {path}, {unzip_dir} not found. ' \
|
||||
f'path/to/abc.zip MUST unzip to path/to/abc/'
|
||||
return True, str(unzip_dir), self._find_yaml(unzip_dir) # zipped, data_dir, yaml_path
|
||||
|
||||
def _hub_ops(self, f):
|
||||
"""Saves a compressed image for HUB previews."""
|
||||
compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub
|
||||
|
||||
def get_json(self, save=False, verbose=False):
|
||||
"""Return dataset JSON for Ultralytics HUB."""
|
||||
from ultralytics.yolo.data import YOLODataset # ClassificationDataset
|
||||
|
||||
def _round(labels):
|
||||
"""Update labels to integer class and 4 decimal place floats."""
|
||||
if self.task == 'detect':
|
||||
coordinates = labels['bboxes']
|
||||
elif self.task == 'segment':
|
||||
coordinates = [x.flatten() for x in labels['segments']]
|
||||
elif self.task == 'pose':
|
||||
n = labels['keypoints'].shape[0]
|
||||
coordinates = np.concatenate((labels['bboxes'], labels['keypoints'].reshape(n, -1)), 1)
|
||||
else:
|
||||
raise ValueError('Undefined dataset task.')
|
||||
zipped = zip(labels['cls'], coordinates)
|
||||
return [[int(c), *(round(float(x), 4) for x in points)] for c, points in zipped]
|
||||
|
||||
for split in 'train', 'val', 'test':
|
||||
if self.data.get(split) is None:
|
||||
self.stats[split] = None # i.e. no test set
|
||||
continue
|
||||
|
||||
dataset = YOLODataset(img_path=self.data[split],
|
||||
data=self.data,
|
||||
use_segments=self.task == 'segment',
|
||||
use_keypoints=self.task == 'pose')
|
||||
x = np.array([
|
||||
np.bincount(label['cls'].astype(int).flatten(), minlength=self.data['nc'])
|
||||
for label in tqdm(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80)
|
||||
self.stats[split] = {
|
||||
'instance_stats': {
|
||||
'total': int(x.sum()),
|
||||
'per_class': x.sum(0).tolist()},
|
||||
'image_stats': {
|
||||
'total': len(dataset),
|
||||
'unlabelled': int(np.all(x == 0, 1).sum()),
|
||||
'per_class': (x > 0).sum(0).tolist()},
|
||||
'labels': [{
|
||||
Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)]}
|
||||
|
||||
# Save, print and return
|
||||
if save:
|
||||
stats_path = self.hub_dir / 'stats.json'
|
||||
LOGGER.info(f'Saving {stats_path.resolve()}...')
|
||||
with open(stats_path, 'w') as f:
|
||||
json.dump(self.stats, f) # save stats.json
|
||||
if verbose:
|
||||
LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
|
||||
return self.stats
|
||||
|
||||
def process_images(self):
|
||||
"""Compress images for Ultralytics HUB."""
|
||||
from ultralytics.yolo.data import YOLODataset # ClassificationDataset
|
||||
|
||||
for split in 'train', 'val', 'test':
|
||||
if self.data.get(split) is None:
|
||||
continue
|
||||
dataset = YOLODataset(img_path=self.data[split], data=self.data)
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
|
||||
pass
|
||||
LOGGER.info(f'Done. All images saved to {self.im_dir}')
|
||||
return self.im_dir
|
||||
|
||||
|
||||
def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
|
||||
"""
|
||||
Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the
|
||||
Python Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will
|
||||
not be resized.
|
||||
|
||||
Args:
|
||||
f (str): The path to the input image file.
|
||||
f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
|
||||
max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
|
||||
quality (int, optional): The image compression quality as a percentage. Default is 50%.
|
||||
|
||||
Usage:
|
||||
from pathlib import Path
|
||||
from ultralytics.yolo.data.utils import compress_one_image
|
||||
for f in Path('/Users/glennjocher/Downloads/dataset').rglob('*.jpg'):
|
||||
compress_one_image(f)
|
||||
"""
|
||||
try: # use PIL
|
||||
im = Image.open(f)
|
||||
r = max_dim / max(im.height, im.width) # ratio
|
||||
if r < 1.0: # image too large
|
||||
im = im.resize((int(im.width * r), int(im.height * r)))
|
||||
im.save(f_new or f, 'JPEG', quality=quality, optimize=True) # save
|
||||
except Exception as e: # use OpenCV
|
||||
LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
|
||||
im = cv2.imread(f)
|
||||
im_height, im_width = im.shape[:2]
|
||||
r = max_dim / max(im_height, im_width) # ratio
|
||||
if r < 1.0: # image too large
|
||||
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
|
||||
cv2.imwrite(str(f_new or f), im)
|
||||
|
||||
|
||||
def delete_dsstore(path):
|
||||
"""
|
||||
Deletes all ".DS_store" files under a specified directory.
|
||||
|
||||
Args:
|
||||
path (str, optional): The directory path where the ".DS_store" files should be deleted.
|
||||
|
||||
Usage:
|
||||
from ultralytics.yolo.data.utils import delete_dsstore
|
||||
delete_dsstore('/Users/glennjocher/Downloads/dataset')
|
||||
|
||||
Note:
|
||||
".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
|
||||
are hidden system files and can cause issues when transferring files between different operating systems.
|
||||
"""
|
||||
# Delete Apple .DS_store files
|
||||
files = list(Path(path).rglob('.DS_store'))
|
||||
LOGGER.info(f'Deleting *.DS_store files: {files}')
|
||||
for f in files:
|
||||
f.unlink()
|
||||
|
||||
|
||||
def zip_directory(dir, use_zipfile_library=True):
|
||||
"""
|
||||
Zips a directory and saves the archive to the specified output path.
|
||||
|
||||
Args:
|
||||
dir (str): The path to the directory to be zipped.
|
||||
use_zipfile_library (bool): Whether to use zipfile library or shutil for zipping.
|
||||
|
||||
Usage:
|
||||
from ultralytics.yolo.data.utils import zip_directory
|
||||
zip_directory('/Users/glennjocher/Downloads/playground')
|
||||
|
||||
zip -r coco8-pose.zip coco8-pose
|
||||
"""
|
||||
delete_dsstore(dir)
|
||||
if use_zipfile_library:
|
||||
dir = Path(dir)
|
||||
with zipfile.ZipFile(dir.with_suffix('.zip'), 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
||||
for file_path in dir.glob('**/*'):
|
||||
if file_path.is_file():
|
||||
zip_file.write(file_path, file_path.relative_to(dir))
|
||||
else:
|
||||
import shutil
|
||||
shutil.make_archive(dir, 'zip', dir)
|
||||
|
||||
|
||||
def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
|
||||
"""
|
||||
Autosplit a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.
|
||||
|
||||
Args:
|
||||
path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco128/images'.
|
||||
weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0).
|
||||
annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False.
|
||||
|
||||
Usage:
|
||||
from utils.dataloaders import autosplit
|
||||
autosplit()
|
||||
"""
|
||||
|
||||
path = Path(path) # images dir
|
||||
files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
|
||||
n = len(files) # number of files
|
||||
random.seed(0) # for reproducibility
|
||||
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
|
||||
|
||||
txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
|
||||
for x in txt:
|
||||
if (path.parent / x).exists():
|
||||
(path.parent / x).unlink() # remove existing
|
||||
|
||||
LOGGER.info(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
|
||||
for i, img in tqdm(zip(indices, files), total=n):
|
||||
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
|
||||
with open(path.parent / txt[i], 'a') as f:
|
||||
f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n') # add image to txt file
|
@ -0,0 +1,10 @@
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
from ultralytics.utils import LOGGER
|
||||
|
||||
# Set modules in sys.modules under their old name
|
||||
sys.modules['ultralytics.yolo.engine'] = importlib.import_module('ultralytics.engine')
|
||||
|
||||
LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.engine' is deprecated since '8.0.136' and will be removed in '8.1.0'. "
|
||||
"Please use 'ultralytics.engine' instead.")
|
||||
|
@ -1,926 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Export a YOLOv8 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
|
||||
|
||||
Format | `format=argument` | Model
|
||||
--- | --- | ---
|
||||
PyTorch | - | yolov8n.pt
|
||||
TorchScript | `torchscript` | yolov8n.torchscript
|
||||
ONNX | `onnx` | yolov8n.onnx
|
||||
OpenVINO | `openvino` | yolov8n_openvino_model/
|
||||
TensorRT | `engine` | yolov8n.engine
|
||||
CoreML | `coreml` | yolov8n.mlmodel
|
||||
TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/
|
||||
TensorFlow GraphDef | `pb` | yolov8n.pb
|
||||
TensorFlow Lite | `tflite` | yolov8n.tflite
|
||||
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
|
||||
TensorFlow.js | `tfjs` | yolov8n_web_model/
|
||||
PaddlePaddle | `paddle` | yolov8n_paddle_model/
|
||||
NCNN | `ncnn` | yolov8n_ncnn_model/
|
||||
|
||||
Requirements:
|
||||
$ pip install ultralytics[export]
|
||||
|
||||
Python:
|
||||
from ultralytics import YOLO
|
||||
model = YOLO('yolov8n.pt')
|
||||
results = model.export(format='onnx')
|
||||
|
||||
CLI:
|
||||
$ yolo mode=export model=yolov8n.pt format=onnx
|
||||
|
||||
Inference:
|
||||
$ yolo predict model=yolov8n.pt # PyTorch
|
||||
yolov8n.torchscript # TorchScript
|
||||
yolov8n.onnx # ONNX Runtime or OpenCV DNN with --dnn
|
||||
yolov8n_openvino_model # OpenVINO
|
||||
yolov8n.engine # TensorRT
|
||||
yolov8n.mlmodel # CoreML (macOS-only)
|
||||
yolov8n_saved_model # TensorFlow SavedModel
|
||||
yolov8n.pb # TensorFlow GraphDef
|
||||
yolov8n.tflite # TensorFlow Lite
|
||||
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
|
||||
TensorFlow.js:
|
||||
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
|
||||
$ npm install
|
||||
$ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model
|
||||
$ npm start
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.nn.autobackend import check_class_names
|
||||
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
|
||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, __version__, callbacks, colorstr,
|
||||
get_default_args, yaml_save)
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version
|
||||
from ultralytics.yolo.utils.downloads import attempt_download_asset, get_github_assets
|
||||
from ultralytics.yolo.utils.files import file_size
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
|
||||
|
||||
|
||||
def export_formats():
|
||||
"""YOLOv8 export formats."""
|
||||
import pandas
|
||||
x = [
|
||||
['PyTorch', '-', '.pt', True, True],
|
||||
['TorchScript', 'torchscript', '.torchscript', True, True],
|
||||
['ONNX', 'onnx', '.onnx', True, True],
|
||||
['OpenVINO', 'openvino', '_openvino_model', True, False],
|
||||
['TensorRT', 'engine', '.engine', False, True],
|
||||
['CoreML', 'coreml', '.mlmodel', True, False],
|
||||
['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
|
||||
['TensorFlow GraphDef', 'pb', '.pb', True, True],
|
||||
['TensorFlow Lite', 'tflite', '.tflite', True, False],
|
||||
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
|
||||
['TensorFlow.js', 'tfjs', '_web_model', True, False],
|
||||
['PaddlePaddle', 'paddle', '_paddle_model', True, True],
|
||||
['NCNN', 'ncnn', '_ncnn_model', True, True], ]
|
||||
return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
||||
|
||||
|
||||
def gd_outputs(gd):
|
||||
"""TensorFlow GraphDef model output node names."""
|
||||
name_list, input_list = [], []
|
||||
for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
|
||||
name_list.append(node.name)
|
||||
input_list.extend(node.input)
|
||||
return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
|
||||
|
||||
|
||||
def try_export(inner_func):
|
||||
"""YOLOv8 export decorator, i..e @try_export."""
|
||||
inner_args = get_default_args(inner_func)
|
||||
|
||||
def outer_func(*args, **kwargs):
|
||||
"""Export a model."""
|
||||
prefix = inner_args['prefix']
|
||||
try:
|
||||
with Profile() as dt:
|
||||
f, model = inner_func(*args, **kwargs)
|
||||
LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
|
||||
return f, model
|
||||
except Exception as e:
|
||||
LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
|
||||
return None, None
|
||||
|
||||
return outer_func
|
||||
|
||||
|
||||
class Exporter:
|
||||
"""
|
||||
A class for exporting a model.
|
||||
|
||||
Attributes:
|
||||
args (SimpleNamespace): Configuration for the exporter.
|
||||
save_dir (Path): Directory to save results.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initializes the Exporter class.
|
||||
|
||||
Args:
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
_callbacks (list, optional): List of callback functions. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, model=None):
|
||||
"""Returns list of exported files/dirs after running callbacks."""
|
||||
self.run_callbacks('on_export_start')
|
||||
t = time.time()
|
||||
format = self.args.format.lower() # to lowercase
|
||||
if format in ('tensorrt', 'trt'): # engine aliases
|
||||
format = 'engine'
|
||||
fmts = tuple(export_formats()['Argument'][1:]) # available export formats
|
||||
flags = [x == format for x in fmts]
|
||||
if sum(flags) != 1:
|
||||
raise ValueError(f"Invalid export format='{format}'. Valid formats are {fmts}")
|
||||
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
|
||||
|
||||
# Load PyTorch model
|
||||
self.device = select_device('cpu' if self.args.device is None else self.args.device)
|
||||
|
||||
# Checks
|
||||
model.names = check_class_names(model.names)
|
||||
if self.args.half and onnx and self.device.type == 'cpu':
|
||||
LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0')
|
||||
self.args.half = False
|
||||
assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.'
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
||||
if self.args.optimize:
|
||||
assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
|
||||
assert self.device.type == 'cpu', "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
|
||||
if edgetpu and not LINUX:
|
||||
raise SystemError('Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/')
|
||||
|
||||
# Input
|
||||
im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
|
||||
file = Path(
|
||||
getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', ''))
|
||||
if file.suffix == '.yaml':
|
||||
file = Path(file.name)
|
||||
|
||||
# Update model
|
||||
model = deepcopy(model).to(self.device)
|
||||
for p in model.parameters():
|
||||
p.requires_grad = False
|
||||
model.eval()
|
||||
model.float()
|
||||
model = model.fuse()
|
||||
for k, m in model.named_modules():
|
||||
if isinstance(m, (Detect, RTDETRDecoder)): # Segment and Pose use Detect base class
|
||||
m.dynamic = self.args.dynamic
|
||||
m.export = True
|
||||
m.format = self.args.format
|
||||
elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
|
||||
# EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
|
||||
m.forward = m.forward_split
|
||||
|
||||
y = None
|
||||
for _ in range(2):
|
||||
y = model(im) # dry runs
|
||||
if self.args.half and (engine or onnx) and self.device.type != 'cpu':
|
||||
im, model = im.half(), model.half() # to FP16
|
||||
|
||||
# Filter warnings
|
||||
warnings.filterwarnings('ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
|
||||
warnings.filterwarnings('ignore', category=UserWarning) # suppress shape prim::Constant missing ONNX warning
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning) # suppress CoreML np.bool deprecation warning
|
||||
|
||||
# Assign
|
||||
self.im = im
|
||||
self.model = model
|
||||
self.file = file
|
||||
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else \
|
||||
tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
|
||||
self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
|
||||
trained_on = f'trained on {Path(self.args.data).name}' if self.args.data else '(untrained)'
|
||||
description = f'Ultralytics {self.pretty_name} model {trained_on}'
|
||||
self.metadata = {
|
||||
'description': description,
|
||||
'author': 'Ultralytics',
|
||||
'license': 'AGPL-3.0 https://ultralytics.com/license',
|
||||
'date': datetime.now().isoformat(),
|
||||
'version': __version__,
|
||||
'stride': int(max(model.stride)),
|
||||
'task': model.task,
|
||||
'batch': self.args.batch,
|
||||
'imgsz': self.imgsz,
|
||||
'names': model.names} # model metadata
|
||||
if model.task == 'pose':
|
||||
self.metadata['kpt_shape'] = model.model[-1].kpt_shape
|
||||
|
||||
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
|
||||
f'output shape(s) {self.output_shape} ({file_size(file):.1f} MB)')
|
||||
|
||||
# Exports
|
||||
f = [''] * len(fmts) # exported filenames
|
||||
if jit or ncnn: # TorchScript
|
||||
f[0], _ = self.export_torchscript()
|
||||
if engine: # TensorRT required before ONNX
|
||||
f[1], _ = self.export_engine()
|
||||
if onnx or xml: # OpenVINO requires ONNX
|
||||
f[2], _ = self.export_onnx()
|
||||
if xml: # OpenVINO
|
||||
f[3], _ = self.export_openvino()
|
||||
if coreml: # CoreML
|
||||
f[4], _ = self.export_coreml()
|
||||
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
|
||||
self.args.int8 |= edgetpu
|
||||
f[5], s_model = self.export_saved_model()
|
||||
if pb or tfjs: # pb prerequisite to tfjs
|
||||
f[6], _ = self.export_pb(s_model)
|
||||
if tflite:
|
||||
f[7], _ = self.export_tflite(s_model, nms=False, agnostic_nms=self.args.agnostic_nms)
|
||||
if edgetpu:
|
||||
f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f'{self.file.stem}_full_integer_quant.tflite')
|
||||
if tfjs:
|
||||
f[9], _ = self.export_tfjs()
|
||||
if paddle: # PaddlePaddle
|
||||
f[10], _ = self.export_paddle()
|
||||
if ncnn: # NCNN
|
||||
f[11], _ = self.export_ncnn()
|
||||
|
||||
# Finish
|
||||
f = [str(x) for x in f if x] # filter out '' and None
|
||||
if any(f):
|
||||
f = str(Path(f[-1]))
|
||||
square = self.imgsz[0] == self.imgsz[1]
|
||||
s = '' if square else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " \
|
||||
f"work. Use export 'imgsz={max(self.imgsz)}' if val is required."
|
||||
imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(' ', '')
|
||||
data = f'data={self.args.data}' if model.task == 'segment' and format == 'pb' else ''
|
||||
LOGGER.info(
|
||||
f'\nExport complete ({time.time() - t:.1f}s)'
|
||||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
||||
f'\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {data}'
|
||||
f'\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={self.args.data} {s}'
|
||||
f'\nVisualize: https://netron.app')
|
||||
|
||||
self.run_callbacks('on_export_end')
|
||||
return f # return list of exported files/dirs
|
||||
|
||||
@try_export
|
||||
def export_torchscript(self, prefix=colorstr('TorchScript:')):
|
||||
"""YOLOv8 TorchScript model export."""
|
||||
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
|
||||
f = self.file.with_suffix('.torchscript')
|
||||
|
||||
ts = torch.jit.trace(self.model, self.im, strict=False)
|
||||
extra_files = {'config.txt': json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
|
||||
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
|
||||
LOGGER.info(f'{prefix} optimizing for mobile...')
|
||||
from torch.utils.mobile_optimizer import optimize_for_mobile
|
||||
optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
|
||||
else:
|
||||
ts.save(str(f), _extra_files=extra_files)
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def export_onnx(self, prefix=colorstr('ONNX:')):
|
||||
"""YOLOv8 ONNX export."""
|
||||
requirements = ['onnx>=1.12.0']
|
||||
if self.args.simplify:
|
||||
requirements += ['onnxsim>=0.4.17', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
|
||||
check_requirements(requirements)
|
||||
import onnx # noqa
|
||||
|
||||
opset_version = self.args.opset or get_latest_opset()
|
||||
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...')
|
||||
f = str(self.file.with_suffix('.onnx'))
|
||||
|
||||
output_names = ['output0', 'output1'] if isinstance(self.model, SegmentationModel) else ['output0']
|
||||
dynamic = self.args.dynamic
|
||||
if dynamic:
|
||||
dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640)
|
||||
if isinstance(self.model, SegmentationModel):
|
||||
dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
|
||||
dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160)
|
||||
elif isinstance(self.model, DetectionModel):
|
||||
dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
|
||||
|
||||
torch.onnx.export(
|
||||
self.model.cpu() if dynamic else self.model, # --dynamic only compatible with cpu
|
||||
self.im.cpu() if dynamic else self.im,
|
||||
f,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
|
||||
input_names=['images'],
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic or None)
|
||||
|
||||
# Checks
|
||||
model_onnx = onnx.load(f) # load onnx model
|
||||
# onnx.checker.check_model(model_onnx) # check onnx model
|
||||
|
||||
# Simplify
|
||||
if self.args.simplify:
|
||||
try:
|
||||
import onnxsim
|
||||
|
||||
LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
|
||||
# subprocess.run(f'onnxsim {f} {f}', shell=True)
|
||||
model_onnx, check = onnxsim.simplify(model_onnx)
|
||||
assert check, 'Simplified ONNX model could not be validated'
|
||||
except Exception as e:
|
||||
LOGGER.info(f'{prefix} simplifier failure: {e}')
|
||||
|
||||
# Metadata
|
||||
for k, v in self.metadata.items():
|
||||
meta = model_onnx.metadata_props.add()
|
||||
meta.key, meta.value = k, str(v)
|
||||
|
||||
onnx.save(model_onnx, f)
|
||||
return f, model_onnx
|
||||
|
||||
@try_export
|
||||
def export_openvino(self, prefix=colorstr('OpenVINO:')):
|
||||
"""YOLOv8 OpenVINO export."""
|
||||
check_requirements('openvino-dev>=2023.0') # requires openvino-dev: https://pypi.org/project/openvino-dev/
|
||||
import openvino.runtime as ov # noqa
|
||||
from openvino.tools import mo # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with openvino {ov.__version__}...')
|
||||
f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
|
||||
f_onnx = self.file.with_suffix('.onnx')
|
||||
f_ov = str(Path(f) / self.file.with_suffix('.xml').name)
|
||||
|
||||
ov_model = mo.convert_model(f_onnx,
|
||||
model_name=self.pretty_name,
|
||||
framework='onnx',
|
||||
compress_to_fp16=self.args.half) # export
|
||||
|
||||
# Set RT info
|
||||
ov_model.set_rt_info('YOLOv8', ['model_info', 'model_type'])
|
||||
ov_model.set_rt_info(True, ['model_info', 'reverse_input_channels'])
|
||||
ov_model.set_rt_info(114, ['model_info', 'pad_value'])
|
||||
ov_model.set_rt_info([255.0], ['model_info', 'scale_values'])
|
||||
ov_model.set_rt_info(self.args.iou, ['model_info', 'iou_threshold'])
|
||||
ov_model.set_rt_info([v.replace(' ', '_') for k, v in sorted(self.model.names.items())],
|
||||
['model_info', 'labels'])
|
||||
if self.model.task != 'classify':
|
||||
ov_model.set_rt_info('fit_to_window_letterbox', ['model_info', 'resize_type'])
|
||||
|
||||
ov.serialize(ov_model, f_ov) # save
|
||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def export_paddle(self, prefix=colorstr('PaddlePaddle:')):
|
||||
"""YOLOv8 Paddle export."""
|
||||
check_requirements(('paddlepaddle', 'x2paddle'))
|
||||
import x2paddle # noqa
|
||||
from x2paddle.convert import pytorch2paddle # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
|
||||
f = str(self.file).replace(self.file.suffix, f'_paddle_model{os.sep}')
|
||||
|
||||
pytorch2paddle(module=self.model, save_dir=f, jit_type='trace', input_examples=[self.im]) # export
|
||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def export_ncnn(self, prefix=colorstr('NCNN:')):
|
||||
"""
|
||||
YOLOv8 NCNN export using PNNX https://github.com/pnnx/pnnx.
|
||||
"""
|
||||
check_requirements('git+https://github.com/Tencent/ncnn.git' if ARM64 else 'ncnn') # requires NCNN
|
||||
import ncnn # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with NCNN {ncnn.__version__}...')
|
||||
f = Path(str(self.file).replace(self.file.suffix, f'_ncnn_model{os.sep}'))
|
||||
f_ts = str(self.file.with_suffix('.torchscript'))
|
||||
|
||||
if Path('./pnnx').is_file():
|
||||
pnnx = './pnnx'
|
||||
elif (ROOT / 'pnnx').is_file():
|
||||
pnnx = ROOT / 'pnnx'
|
||||
else:
|
||||
LOGGER.warning(
|
||||
f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from '
|
||||
'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory '
|
||||
f'or in {ROOT}. See PNNX repo for full installation instructions.')
|
||||
_, assets = get_github_assets(repo='pnnx/pnnx')
|
||||
asset = [x for x in assets if ('macos' if MACOS else 'ubuntu' if LINUX else 'windows') in x][0]
|
||||
attempt_download_asset(asset, repo='pnnx/pnnx', release='latest')
|
||||
unzip_dir = Path(asset).with_suffix('')
|
||||
pnnx = ROOT / 'pnnx' # new location
|
||||
(unzip_dir / 'pnnx').rename(pnnx) # move binary to ROOT
|
||||
shutil.rmtree(unzip_dir) # delete unzip dir
|
||||
Path(asset).unlink() # delete zip
|
||||
pnnx.chmod(0o777) # set read, write, and execute permissions for everyone
|
||||
|
||||
cmd = [
|
||||
str(pnnx),
|
||||
f_ts,
|
||||
f'pnnxparam={f / "model.pnnx.param"}',
|
||||
f'pnnxbin={f / "model.pnnx.bin"}',
|
||||
f'pnnxpy={f / "model_pnnx.py"}',
|
||||
f'pnnxonnx={f / "model.pnnx.onnx"}',
|
||||
f'ncnnparam={f / "model.ncnn.param"}',
|
||||
f'ncnnbin={f / "model.ncnn.bin"}',
|
||||
f'ncnnpy={f / "model_ncnn.py"}',
|
||||
f'fp16={int(self.args.half)}',
|
||||
f'device={self.device.type}',
|
||||
f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', ]
|
||||
f.mkdir(exist_ok=True) # make ncnn_model directory
|
||||
LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
|
||||
subprocess.run(cmd, check=True)
|
||||
for f_debug in 'debug.bin', 'debug.param', 'debug2.bin', 'debug2.param': # remove debug files
|
||||
Path(f_debug).unlink(missing_ok=True)
|
||||
|
||||
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
return str(f), None
|
||||
|
||||
@try_export
|
||||
def export_coreml(self, prefix=colorstr('CoreML:')):
|
||||
"""YOLOv8 CoreML export."""
|
||||
check_requirements('coremltools>=6.0')
|
||||
import coremltools as ct # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
|
||||
f = self.file.with_suffix('.mlmodel')
|
||||
|
||||
bias = [0.0, 0.0, 0.0]
|
||||
scale = 1 / 255
|
||||
classifier_config = None
|
||||
if self.model.task == 'classify':
|
||||
classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
|
||||
model = self.model
|
||||
elif self.model.task == 'detect':
|
||||
model = iOSDetectModel(self.model, self.im) if self.args.nms else self.model
|
||||
else:
|
||||
# TODO CoreML Segment and Pose model pipelining
|
||||
model = self.model
|
||||
|
||||
ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model
|
||||
ct_model = ct.convert(ts,
|
||||
inputs=[ct.ImageType('image', shape=self.im.shape, scale=scale, bias=bias)],
|
||||
classifier_config=classifier_config)
|
||||
bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
|
||||
if bits < 32:
|
||||
if 'kmeans' in mode:
|
||||
check_requirements('scikit-learn') # scikit-learn package required for k-means quantization
|
||||
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
|
||||
if self.args.nms and self.model.task == 'detect':
|
||||
ct_model = self._pipeline_coreml(ct_model)
|
||||
|
||||
m = self.metadata # metadata dict
|
||||
ct_model.short_description = m.pop('description')
|
||||
ct_model.author = m.pop('author')
|
||||
ct_model.license = m.pop('license')
|
||||
ct_model.version = m.pop('version')
|
||||
ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
|
||||
ct_model.save(str(f))
|
||||
return f, ct_model
|
||||
|
||||
@try_export
|
||||
def export_engine(self, prefix=colorstr('TensorRT:')):
|
||||
"""YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
|
||||
assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'"
|
||||
try:
|
||||
import tensorrt as trt # noqa
|
||||
except ImportError:
|
||||
if LINUX:
|
||||
check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
|
||||
import tensorrt as trt # noqa
|
||||
|
||||
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
|
||||
self.args.simplify = True
|
||||
f_onnx, _ = self.export_onnx()
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
|
||||
assert Path(f_onnx).exists(), f'failed to export ONNX file: {f_onnx}'
|
||||
f = self.file.with_suffix('.engine') # TensorRT engine file
|
||||
logger = trt.Logger(trt.Logger.INFO)
|
||||
if self.args.verbose:
|
||||
logger.min_severity = trt.Logger.Severity.VERBOSE
|
||||
|
||||
builder = trt.Builder(logger)
|
||||
config = builder.create_builder_config()
|
||||
config.max_workspace_size = self.args.workspace * 1 << 30
|
||||
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
|
||||
|
||||
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
||||
network = builder.create_network(flag)
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
if not parser.parse_from_file(f_onnx):
|
||||
raise RuntimeError(f'failed to load ONNX file: {f_onnx}')
|
||||
|
||||
inputs = [network.get_input(i) for i in range(network.num_inputs)]
|
||||
outputs = [network.get_output(i) for i in range(network.num_outputs)]
|
||||
for inp in inputs:
|
||||
LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
|
||||
for out in outputs:
|
||||
LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
|
||||
|
||||
if self.args.dynamic:
|
||||
shape = self.im.shape
|
||||
if shape[0] <= 1:
|
||||
LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
|
||||
profile = builder.create_optimization_profile()
|
||||
for inp in inputs:
|
||||
profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
|
||||
config.add_optimization_profile(profile)
|
||||
|
||||
LOGGER.info(
|
||||
f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}')
|
||||
if builder.platform_has_fast_fp16 and self.args.half:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
|
||||
# Write file
|
||||
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
|
||||
# Metadata
|
||||
meta = json.dumps(self.metadata)
|
||||
t.write(len(meta).to_bytes(4, byteorder='little', signed=True))
|
||||
t.write(meta.encode())
|
||||
# Model
|
||||
t.write(engine.serialize())
|
||||
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def export_saved_model(self, prefix=colorstr('TensorFlow SavedModel:')):
|
||||
"""YOLOv8 TensorFlow SavedModel export."""
|
||||
try:
|
||||
import tensorflow as tf # noqa
|
||||
except ImportError:
|
||||
cuda = torch.cuda.is_available()
|
||||
check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
|
||||
import tensorflow as tf # noqa
|
||||
check_requirements(('onnx', 'onnx2tf>=1.7.7', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.17', 'onnx_graphsurgeon>=0.3.26',
|
||||
'tflite_support', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),
|
||||
cmds='--extra-index-url https://pypi.ngc.nvidia.com')
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
f = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
|
||||
if f.is_dir():
|
||||
import shutil
|
||||
shutil.rmtree(f) # delete output folder
|
||||
|
||||
# Export to ONNX
|
||||
self.args.simplify = True
|
||||
f_onnx, _ = self.export_onnx()
|
||||
|
||||
# Export to TF
|
||||
int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
|
||||
cmd = f'onnx2tf -i {f_onnx} -o {f} -nuo --non_verbose {int8}'
|
||||
LOGGER.info(f"\n{prefix} running '{cmd.strip()}'")
|
||||
subprocess.run(cmd, shell=True)
|
||||
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
|
||||
# Remove/rename TFLite models
|
||||
if self.args.int8:
|
||||
for file in f.rglob('*_dynamic_range_quant.tflite'):
|
||||
file.rename(file.with_name(file.stem.replace('_dynamic_range_quant', '_int8') + file.suffix))
|
||||
for file in f.rglob('*_integer_quant_with_int16_act.tflite'):
|
||||
file.unlink() # delete extra fp16 activation TFLite files
|
||||
|
||||
# Add TFLite metadata
|
||||
for file in f.rglob('*.tflite'):
|
||||
f.unlink() if 'quant_with_int16_act.tflite' in str(f) else self._add_tflite_metadata(file)
|
||||
|
||||
# Load saved_model
|
||||
keras_model = tf.saved_model.load(f, tags=None, options=None)
|
||||
|
||||
return str(f), keras_model
|
||||
|
||||
@try_export
|
||||
def export_pb(self, keras_model, prefix=colorstr('TensorFlow GraphDef:')):
|
||||
"""YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
|
||||
import tensorflow as tf # noqa
|
||||
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
f = self.file.with_suffix('.pb')
|
||||
|
||||
m = tf.function(lambda x: keras_model(x)) # full model
|
||||
m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
|
||||
frozen_func = convert_variables_to_constants_v2(m)
|
||||
frozen_func.graph.as_graph_def()
|
||||
tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
|
||||
"""YOLOv8 TensorFlow Lite export."""
|
||||
import tensorflow as tf # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
|
||||
if self.args.int8:
|
||||
f = saved_model / f'{self.file.stem}_int8.tflite' # fp32 in/out
|
||||
elif self.args.half:
|
||||
f = saved_model / f'{self.file.stem}_float16.tflite' # fp32 in/out
|
||||
else:
|
||||
f = saved_model / f'{self.file.stem}_float32.tflite'
|
||||
return str(f), None
|
||||
|
||||
@try_export
|
||||
def export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
|
||||
"""YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
|
||||
LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185')
|
||||
|
||||
cmd = 'edgetpu_compiler --version'
|
||||
help_url = 'https://coral.ai/docs/edgetpu/compiler/'
|
||||
assert LINUX, f'export only supported on Linux. See {help_url}'
|
||||
if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
|
||||
LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
|
||||
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
|
||||
for c in (
|
||||
'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
|
||||
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
|
||||
'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
|
||||
subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
|
||||
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
|
||||
f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model
|
||||
|
||||
cmd = f'edgetpu_compiler -s -d -k 10 --out_dir {Path(f).parent} {tflite_model}'
|
||||
LOGGER.info(f"{prefix} running '{cmd}'")
|
||||
subprocess.run(cmd.split(), check=True)
|
||||
self._add_tflite_metadata(f)
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
|
||||
"""YOLOv8 TensorFlow.js export."""
|
||||
check_requirements('tensorflowjs')
|
||||
import tensorflow as tf
|
||||
import tensorflowjs as tfjs # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
|
||||
f = str(self.file).replace(self.file.suffix, '_web_model') # js dir
|
||||
f_pb = self.file.with_suffix('.pb') # *.pb path
|
||||
|
||||
gd = tf.Graph().as_graph_def() # TF GraphDef
|
||||
with open(f_pb, 'rb') as file:
|
||||
gd.ParseFromString(file.read())
|
||||
outputs = ','.join(gd_outputs(gd))
|
||||
LOGGER.info(f'\n{prefix} output node names: {outputs}')
|
||||
|
||||
cmd = f'tensorflowjs_converter --input_format=tf_frozen_model --output_node_names={outputs} {f_pb} {f}'
|
||||
subprocess.run(cmd.split(), check=True)
|
||||
|
||||
# f_json = Path(f) / 'model.json' # *.json path
|
||||
# with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
|
||||
# subst = re.sub(
|
||||
# r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
|
||||
# r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
||||
# r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
||||
# r'"Identity.?.?": {"name": "Identity.?.?"}}}',
|
||||
# r'{"outputs": {"Identity": {"name": "Identity"}, '
|
||||
# r'"Identity_1": {"name": "Identity_1"}, '
|
||||
# r'"Identity_2": {"name": "Identity_2"}, '
|
||||
# r'"Identity_3": {"name": "Identity_3"}}}',
|
||||
# f_json.read_text(),
|
||||
# )
|
||||
# j.write(subst)
|
||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
return f, None
|
||||
|
||||
def _add_tflite_metadata(self, file):
|
||||
"""Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
|
||||
from tflite_support import flatbuffers # noqa
|
||||
from tflite_support import metadata as _metadata # noqa
|
||||
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
|
||||
|
||||
# Create model info
|
||||
model_meta = _metadata_fb.ModelMetadataT()
|
||||
model_meta.name = self.metadata['description']
|
||||
model_meta.version = self.metadata['version']
|
||||
model_meta.author = self.metadata['author']
|
||||
model_meta.license = self.metadata['license']
|
||||
|
||||
# Label file
|
||||
tmp_file = Path(file).parent / 'temp_meta.txt'
|
||||
with open(tmp_file, 'w') as f:
|
||||
f.write(str(self.metadata))
|
||||
|
||||
label_file = _metadata_fb.AssociatedFileT()
|
||||
label_file.name = tmp_file.name
|
||||
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
|
||||
|
||||
# Create input info
|
||||
input_meta = _metadata_fb.TensorMetadataT()
|
||||
input_meta.name = 'image'
|
||||
input_meta.description = 'Input image to be detected.'
|
||||
input_meta.content = _metadata_fb.ContentT()
|
||||
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
|
||||
input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
|
||||
input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties
|
||||
|
||||
# Create output info
|
||||
output1 = _metadata_fb.TensorMetadataT()
|
||||
output1.name = 'output'
|
||||
output1.description = 'Coordinates of detected objects, class labels, and confidence score'
|
||||
output1.associatedFiles = [label_file]
|
||||
if self.model.task == 'segment':
|
||||
output2 = _metadata_fb.TensorMetadataT()
|
||||
output2.name = 'output'
|
||||
output2.description = 'Mask protos'
|
||||
output2.associatedFiles = [label_file]
|
||||
|
||||
# Create subgraph info
|
||||
subgraph = _metadata_fb.SubGraphMetadataT()
|
||||
subgraph.inputTensorMetadata = [input_meta]
|
||||
subgraph.outputTensorMetadata = [output1, output2] if self.model.task == 'segment' else [output1]
|
||||
model_meta.subgraphMetadata = [subgraph]
|
||||
|
||||
b = flatbuffers.Builder(0)
|
||||
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||
metadata_buf = b.Output()
|
||||
|
||||
populator = _metadata.MetadataPopulator.with_model_file(str(file))
|
||||
populator.load_metadata_buffer(metadata_buf)
|
||||
populator.load_associated_files([str(tmp_file)])
|
||||
populator.populate()
|
||||
tmp_file.unlink()
|
||||
|
||||
def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
|
||||
"""YOLOv8 CoreML pipeline."""
|
||||
import coremltools as ct # noqa
|
||||
|
||||
LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
|
||||
batch_size, ch, h, w = list(self.im.shape) # BCHW
|
||||
|
||||
# Output shapes
|
||||
spec = model.get_spec()
|
||||
out0, out1 = iter(spec.description.output)
|
||||
if MACOS:
|
||||
from PIL import Image
|
||||
img = Image.new('RGB', (w, h)) # img(192 width, 320 height)
|
||||
# img = torch.zeros((*opt.img_size, 3)).numpy() # img size(320,192,3) iDetection
|
||||
out = model.predict({'image': img})
|
||||
out0_shape = out[out0.name].shape
|
||||
out1_shape = out[out1.name].shape
|
||||
else: # linux and windows can not run model.predict(), get sizes from pytorch output y
|
||||
out0_shape = self.output_shape[2], self.output_shape[1] - 4 # (3780, 80)
|
||||
out1_shape = self.output_shape[2], 4 # (3780, 4)
|
||||
|
||||
# Checks
|
||||
names = self.metadata['names']
|
||||
nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
|
||||
na, nc = out0_shape
|
||||
# na, nc = out0.type.multiArrayType.shape # number anchors, classes
|
||||
assert len(names) == nc, f'{len(names)} names found for nc={nc}' # check
|
||||
|
||||
# Define output shapes (missing)
|
||||
out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80)
|
||||
out1.type.multiArrayType.shape[:] = out1_shape # (3780, 4)
|
||||
# spec.neuralNetwork.preprocessing[0].featureName = '0'
|
||||
|
||||
# Flexible input shapes
|
||||
# from coremltools.models.neural_network import flexible_shape_utils
|
||||
# s = [] # shapes
|
||||
# s.append(flexible_shape_utils.NeuralNetworkImageSize(320, 192))
|
||||
# s.append(flexible_shape_utils.NeuralNetworkImageSize(640, 384)) # (height, width)
|
||||
# flexible_shape_utils.add_enumerated_image_sizes(spec, feature_name='image', sizes=s)
|
||||
# r = flexible_shape_utils.NeuralNetworkImageSizeRange() # shape ranges
|
||||
# r.add_height_range((192, 640))
|
||||
# r.add_width_range((192, 640))
|
||||
# flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)
|
||||
|
||||
# Print
|
||||
# print(spec.description)
|
||||
|
||||
# Model from spec
|
||||
model = ct.models.MLModel(spec)
|
||||
|
||||
# 3. Create NMS protobuf
|
||||
nms_spec = ct.proto.Model_pb2.Model()
|
||||
nms_spec.specificationVersion = 5
|
||||
for i in range(2):
|
||||
decoder_output = model._spec.description.output[i].SerializeToString()
|
||||
nms_spec.description.input.add()
|
||||
nms_spec.description.input[i].ParseFromString(decoder_output)
|
||||
nms_spec.description.output.add()
|
||||
nms_spec.description.output[i].ParseFromString(decoder_output)
|
||||
|
||||
nms_spec.description.output[0].name = 'confidence'
|
||||
nms_spec.description.output[1].name = 'coordinates'
|
||||
|
||||
output_sizes = [nc, 4]
|
||||
for i in range(2):
|
||||
ma_type = nms_spec.description.output[i].type.multiArrayType
|
||||
ma_type.shapeRange.sizeRanges.add()
|
||||
ma_type.shapeRange.sizeRanges[0].lowerBound = 0
|
||||
ma_type.shapeRange.sizeRanges[0].upperBound = -1
|
||||
ma_type.shapeRange.sizeRanges.add()
|
||||
ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
|
||||
ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
|
||||
del ma_type.shape[:]
|
||||
|
||||
nms = nms_spec.nonMaximumSuppression
|
||||
nms.confidenceInputFeatureName = out0.name # 1x507x80
|
||||
nms.coordinatesInputFeatureName = out1.name # 1x507x4
|
||||
nms.confidenceOutputFeatureName = 'confidence'
|
||||
nms.coordinatesOutputFeatureName = 'coordinates'
|
||||
nms.iouThresholdInputFeatureName = 'iouThreshold'
|
||||
nms.confidenceThresholdInputFeatureName = 'confidenceThreshold'
|
||||
nms.iouThreshold = 0.45
|
||||
nms.confidenceThreshold = 0.25
|
||||
nms.pickTop.perClass = True
|
||||
nms.stringClassLabels.vector.extend(names.values())
|
||||
nms_model = ct.models.MLModel(nms_spec)
|
||||
|
||||
# 4. Pipeline models together
|
||||
pipeline = ct.models.pipeline.Pipeline(input_features=[('image', ct.models.datatypes.Array(3, ny, nx)),
|
||||
('iouThreshold', ct.models.datatypes.Double()),
|
||||
('confidenceThreshold', ct.models.datatypes.Double())],
|
||||
output_features=['confidence', 'coordinates'])
|
||||
pipeline.add_model(model)
|
||||
pipeline.add_model(nms_model)
|
||||
|
||||
# Correct datatypes
|
||||
pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
|
||||
pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
|
||||
pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())
|
||||
|
||||
# Update metadata
|
||||
pipeline.spec.specificationVersion = 5
|
||||
pipeline.spec.description.metadata.userDefined.update({
|
||||
'IoU threshold': str(nms.iouThreshold),
|
||||
'Confidence threshold': str(nms.confidenceThreshold)})
|
||||
|
||||
# Save the model
|
||||
model = ct.models.MLModel(pipeline.spec)
|
||||
model.input_description['image'] = 'Input image'
|
||||
model.input_description['iouThreshold'] = f'(optional) IOU threshold override (default: {nms.iouThreshold})'
|
||||
model.input_description['confidenceThreshold'] = \
|
||||
f'(optional) Confidence threshold override (default: {nms.confidenceThreshold})'
|
||||
model.output_description['confidence'] = 'Boxes × Class confidence (see user-defined metadata "classes")'
|
||||
model.output_description['coordinates'] = 'Boxes × [x, y, width, height] (relative to image size)'
|
||||
LOGGER.info(f'{prefix} pipeline success')
|
||||
return model
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
Appends the given callback.
|
||||
"""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Execute all callbacks for a given event."""
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
|
||||
class iOSDetectModel(torch.nn.Module):
|
||||
"""Wrap an Ultralytics YOLO model for iOS export."""
|
||||
|
||||
def __init__(self, model, im):
|
||||
"""Initialize the iOSDetectModel class with a YOLO model and example image."""
|
||||
super().__init__()
|
||||
b, c, h, w = im.shape # batch, channel, height, width
|
||||
self.model = model
|
||||
self.nc = len(model.names) # number of classes
|
||||
if w == h:
|
||||
self.normalize = 1.0 / w # scalar
|
||||
else:
|
||||
self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)
|
||||
|
||||
def forward(self, x):
|
||||
"""Normalize predictions of object detection model with input size-dependent factors."""
|
||||
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
|
||||
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
|
||||
|
||||
|
||||
def export(cfg=DEFAULT_CFG):
|
||||
"""Export a YOLOv model to a specific format."""
|
||||
cfg.model = cfg.model or 'yolov8n.yaml'
|
||||
cfg.format = cfg.format or 'torchscript'
|
||||
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(cfg.model)
|
||||
model.export(**vars(cfg))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
CLI:
|
||||
yolo mode=export model=yolov8n.yaml format=onnx
|
||||
"""
|
||||
export()
|
@ -1,436 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from ultralytics import yolo # noqa
|
||||
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel,
|
||||
attempt_load_one_weight, guess_model_task, nn, yaml_model_load)
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
|
||||
is_git_dir, yaml_load)
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
|
||||
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
|
||||
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||
|
||||
# Map head to model, trainer, validator, and predictor classes
|
||||
TASK_MAP = {
|
||||
'classify': [
|
||||
ClassificationModel, yolo.v8.classify.ClassificationTrainer, yolo.v8.classify.ClassificationValidator,
|
||||
yolo.v8.classify.ClassificationPredictor],
|
||||
'detect': [
|
||||
DetectionModel, yolo.v8.detect.DetectionTrainer, yolo.v8.detect.DetectionValidator,
|
||||
yolo.v8.detect.DetectionPredictor],
|
||||
'segment': [
|
||||
SegmentationModel, yolo.v8.segment.SegmentationTrainer, yolo.v8.segment.SegmentationValidator,
|
||||
yolo.v8.segment.SegmentationPredictor],
|
||||
'pose': [PoseModel, yolo.v8.pose.PoseTrainer, yolo.v8.pose.PoseValidator, yolo.v8.pose.PosePredictor]}
|
||||
|
||||
|
||||
class YOLO:
|
||||
"""
|
||||
YOLO (You Only Look Once) object detection model.
|
||||
|
||||
Args:
|
||||
model (str, Path): Path to the model file to load or create.
|
||||
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
||||
|
||||
Attributes:
|
||||
predictor (Any): The predictor object.
|
||||
model (Any): The model object.
|
||||
trainer (Any): The trainer object.
|
||||
task (str): The type of model task.
|
||||
ckpt (Any): The checkpoint object if the model loaded from *.pt file.
|
||||
cfg (str): The model configuration if loaded from *.yaml file.
|
||||
ckpt_path (str): The checkpoint file path.
|
||||
overrides (dict): Overrides for the trainer object.
|
||||
metrics (Any): The data for metrics.
|
||||
|
||||
Methods:
|
||||
__call__(source=None, stream=False, **kwargs):
|
||||
Alias for the predict method.
|
||||
_new(cfg:str, verbose:bool=True) -> None:
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
_load(weights:str, task:str='') -> None:
|
||||
Initializes a new model and infers the task type from the model head.
|
||||
_check_is_pytorch_model() -> None:
|
||||
Raises TypeError if the model is not a PyTorch model.
|
||||
reset() -> None:
|
||||
Resets the model modules.
|
||||
info(verbose:bool=False) -> None:
|
||||
Logs the model info.
|
||||
fuse() -> None:
|
||||
Fuses the model for faster inference.
|
||||
predict(source=None, stream=False, **kwargs) -> List[ultralytics.yolo.engine.results.Results]:
|
||||
Performs prediction using the YOLO model.
|
||||
|
||||
Returns:
|
||||
list(ultralytics.yolo.engine.results.Results): The prediction results.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
|
||||
"""
|
||||
Initializes the YOLO model.
|
||||
|
||||
Args:
|
||||
model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'.
|
||||
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
||||
"""
|
||||
self.callbacks = callbacks.get_default_callbacks()
|
||||
self.predictor = None # reuse predictor
|
||||
self.model = None # model object
|
||||
self.trainer = None # trainer object
|
||||
self.task = None # task type
|
||||
self.ckpt = None # if loaded from *.pt
|
||||
self.cfg = None # if loaded from *.yaml
|
||||
self.ckpt_path = None
|
||||
self.overrides = {} # overrides for trainer object
|
||||
self.metrics = None # validation/training metrics
|
||||
self.session = None # HUB session
|
||||
model = str(model).strip() # strip spaces
|
||||
|
||||
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
||||
if self.is_hub_model(model):
|
||||
from ultralytics.hub.session import HUBTrainingSession
|
||||
self.session = HUBTrainingSession(model)
|
||||
model = self.session.model_file
|
||||
|
||||
# Load or create new YOLO model
|
||||
suffix = Path(model).suffix
|
||||
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
|
||||
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
if suffix == '.yaml':
|
||||
self._new(model, task)
|
||||
else:
|
||||
self._load(model, task)
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
"""Calls the 'predict' function with given arguments to perform object detection."""
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
"""Raises error if object has no requested attribute."""
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
@staticmethod
|
||||
def is_hub_model(model):
|
||||
"""Check if the provided model is a HUB model."""
|
||||
return any((
|
||||
model.startswith('https://hub.ultralytics.com/models/'), # i.e. https://hub.ultralytics.com/models/MODEL_ID
|
||||
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
|
||||
len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
|
||||
|
||||
def _new(self, cfg: str, task=None, verbose=True):
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
|
||||
Args:
|
||||
cfg (str): model configuration file
|
||||
task (str | None): model task
|
||||
verbose (bool): display model info on load
|
||||
"""
|
||||
cfg_dict = yaml_model_load(cfg)
|
||||
self.cfg = cfg
|
||||
self.task = task or guess_model_task(cfg_dict)
|
||||
self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model
|
||||
self.overrides['model'] = self.cfg
|
||||
|
||||
# Below added to allow export from yamls
|
||||
args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args
|
||||
self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
||||
self.model.task = self.task
|
||||
|
||||
def _load(self, weights: str, task=None):
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model head.
|
||||
|
||||
Args:
|
||||
weights (str): model checkpoint to be loaded
|
||||
task (str | None): model task
|
||||
"""
|
||||
suffix = Path(weights).suffix
|
||||
if suffix == '.pt':
|
||||
self.model, self.ckpt = attempt_load_one_weight(weights)
|
||||
self.task = self.model.args['task']
|
||||
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
|
||||
self.ckpt_path = self.model.pt_path
|
||||
else:
|
||||
weights = check_file(weights)
|
||||
self.model, self.ckpt = weights, None
|
||||
self.task = task or guess_model_task(weights)
|
||||
self.ckpt_path = weights
|
||||
self.overrides['model'] = weights
|
||||
self.overrides['task'] = self.task
|
||||
|
||||
def _check_is_pytorch_model(self):
|
||||
"""
|
||||
Raises TypeError is model is not a PyTorch model
|
||||
"""
|
||||
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
|
||||
pt_module = isinstance(self.model, nn.Module)
|
||||
if not (pt_module or pt_str):
|
||||
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
|
||||
f'PyTorch models can be used to train, val, predict and export, i.e. '
|
||||
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
|
||||
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
|
||||
|
||||
@smart_inference_mode()
|
||||
def reset_weights(self):
|
||||
"""
|
||||
Resets the model modules parameters to randomly initialized values, losing all training information.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
for p in self.model.parameters():
|
||||
p.requires_grad = True
|
||||
return self
|
||||
|
||||
@smart_inference_mode()
|
||||
def load(self, weights='yolov8n.pt'):
|
||||
"""
|
||||
Transfers parameters with matching names and shapes from 'weights' to model.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
if isinstance(weights, (str, Path)):
|
||||
weights, self.ckpt = attempt_load_one_weight(weights)
|
||||
self.model.load(weights)
|
||||
return self
|
||||
|
||||
def info(self, detailed=False, verbose=True):
|
||||
"""
|
||||
Logs model info.
|
||||
|
||||
Args:
|
||||
detailed (bool): Show detailed information about model.
|
||||
verbose (bool): Controls verbosity.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
return self.model.info(detailed=detailed, verbose=verbose)
|
||||
|
||||
def fuse(self):
|
||||
"""Fuse PyTorch Conv2d and BatchNorm2d layers."""
|
||||
self._check_is_pytorch_model()
|
||||
self.model.fuse()
|
||||
|
||||
@smart_inference_mode()
|
||||
def predict(self, source=None, stream=False, **kwargs):
|
||||
"""
|
||||
Perform prediction using the YOLO model.
|
||||
|
||||
Args:
|
||||
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
||||
Accepts all source types accepted by the YOLO model.
|
||||
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
||||
**kwargs : Additional keyword arguments passed to the predictor.
|
||||
Check the 'configuration' section in the documentation for all available options.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
||||
"""
|
||||
if source is None:
|
||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
||||
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
|
||||
x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
|
||||
overrides = self.overrides.copy()
|
||||
overrides['conf'] = 0.25
|
||||
overrides.update(kwargs) # prefer kwargs
|
||||
overrides['mode'] = kwargs.get('mode', 'predict')
|
||||
assert overrides['mode'] in ['track', 'predict']
|
||||
if not is_cli:
|
||||
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
|
||||
if not self.predictor:
|
||||
self.task = overrides.get('task') or self.task
|
||||
self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
|
||||
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
||||
else: # only update args if predictor is already setup
|
||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||
if 'project' in overrides or 'name' in overrides:
|
||||
self.predictor.save_dir = self.predictor.get_save_dir()
|
||||
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
||||
|
||||
def track(self, source=None, stream=False, persist=False, **kwargs):
|
||||
"""
|
||||
Perform object tracking on the input source using the registered trackers.
|
||||
|
||||
Args:
|
||||
source (str, optional): The input source for object tracking. Can be a file path or a video stream.
|
||||
stream (bool, optional): Whether the input source is a video stream. Defaults to False.
|
||||
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
|
||||
**kwargs (optional): Additional keyword arguments for the tracking process.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.yolo.engine.results.Results]): The tracking results.
|
||||
|
||||
"""
|
||||
if not hasattr(self.predictor, 'trackers'):
|
||||
from ultralytics.tracker import register_tracker
|
||||
register_tracker(self, persist)
|
||||
# ByteTrack-based method needs low confidence predictions as input
|
||||
conf = kwargs.get('conf') or 0.1
|
||||
kwargs['conf'] = conf
|
||||
kwargs['mode'] = 'track'
|
||||
return self.predict(source=source, stream=stream, **kwargs)
|
||||
|
||||
@smart_inference_mode()
|
||||
def val(self, data=None, **kwargs):
|
||||
"""
|
||||
Validate a model on a given dataset.
|
||||
|
||||
Args:
|
||||
data (str): The dataset to validate on. Accepts all formats accepted by yolo
|
||||
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
||||
"""
|
||||
overrides = self.overrides.copy()
|
||||
overrides['rect'] = True # rect batches as default
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'val'
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.data = data or args.data
|
||||
if 'task' in overrides:
|
||||
self.task = args.task
|
||||
else:
|
||||
args.task = self.task
|
||||
if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
|
||||
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
||||
|
||||
validator = TASK_MAP[self.task][2](args=args, _callbacks=self.callbacks)
|
||||
validator(model=self.model)
|
||||
self.metrics = validator.metrics
|
||||
|
||||
return validator.metrics
|
||||
|
||||
@smart_inference_mode()
|
||||
def benchmark(self, **kwargs):
|
||||
"""
|
||||
Benchmark a model on all export formats.
|
||||
|
||||
Args:
|
||||
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
from ultralytics.yolo.utils.benchmarks import benchmark
|
||||
overrides = self.model.args.copy()
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'benchmark'
|
||||
overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
|
||||
return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])
|
||||
|
||||
def export(self, **kwargs):
|
||||
"""
|
||||
Export model.
|
||||
|
||||
Args:
|
||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'export'
|
||||
if overrides.get('imgsz') is None:
|
||||
overrides['imgsz'] = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
if 'batch' not in kwargs:
|
||||
overrides['batch'] = 1 # default to 1 if not modified
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.task = self.task
|
||||
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
Trains the model on a given dataset.
|
||||
|
||||
Args:
|
||||
**kwargs (Any): Any number of arguments representing the training configuration.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
if self.session: # Ultralytics HUB session
|
||||
if any(kwargs):
|
||||
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
|
||||
kwargs = self.session.train_args
|
||||
check_pip_update_available()
|
||||
overrides = self.overrides.copy()
|
||||
if kwargs.get('cfg'):
|
||||
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
|
||||
overrides = yaml_load(check_yaml(kwargs['cfg']))
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'train'
|
||||
if not overrides.get('data'):
|
||||
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
|
||||
if overrides.get('resume'):
|
||||
overrides['resume'] = self.ckpt_path
|
||||
self.task = overrides.get('task') or self.task
|
||||
self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)
|
||||
if not overrides.get('resume'): # manually set model only if not resuming
|
||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||
self.model = self.trainer.model
|
||||
self.trainer.hub_session = self.session # attach optional HUB session
|
||||
self.trainer.train()
|
||||
# Update model and cfg after training
|
||||
if RANK in (-1, 0):
|
||||
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
|
||||
self.overrides = self.model.args
|
||||
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
|
||||
|
||||
def to(self, device):
|
||||
"""
|
||||
Sends the model to the given device.
|
||||
|
||||
Args:
|
||||
device (str): device
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
self.model.to(device)
|
||||
|
||||
def tune(self, *args, **kwargs):
|
||||
"""
|
||||
Runs hyperparameter tuning using Ray Tune. See ultralytics.yolo.utils.tuner.run_ray_tune for Args.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the results of the hyperparameter search.
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError: If Ray Tune is not installed.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
from ultralytics.yolo.utils.tuner import run_ray_tune
|
||||
return run_ray_tune(self, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
"""Returns class names of the loaded model."""
|
||||
return self.model.names if hasattr(self.model, 'names') else None
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""Returns device if PyTorch model."""
|
||||
return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
|
||||
|
||||
@property
|
||||
def transforms(self):
|
||||
"""Returns transform of the loaded model."""
|
||||
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||
|
||||
def add_callback(self, event: str, func):
|
||||
"""Add a callback."""
|
||||
self.callbacks[event].append(func)
|
||||
|
||||
def clear_callback(self, event: str):
|
||||
"""Clear all event callbacks."""
|
||||
self.callbacks[event] = []
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
"""Reset arguments when loading a PyTorch model."""
|
||||
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
|
||||
return {k: v for k, v in args.items() if k in include}
|
||||
|
||||
def _reset_callbacks(self):
|
||||
"""Reset all registered callbacks."""
|
||||
for event in callbacks.default_callbacks.keys():
|
||||
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
@ -1,357 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
|
||||
|
||||
Usage - sources:
|
||||
$ yolo mode=predict model=yolov8n.pt source=0 # webcam
|
||||
img.jpg # image
|
||||
vid.mp4 # video
|
||||
screen # screenshot
|
||||
path/ # directory
|
||||
list.txt # list of images
|
||||
list.streams # list of streams
|
||||
'path/*.jpg' # glob
|
||||
'https://youtu.be/Zgi9g1ksQHc' # YouTube
|
||||
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
|
||||
|
||||
Usage - formats:
|
||||
$ yolo mode=predict model=yolov8n.pt # PyTorch
|
||||
yolov8n.torchscript # TorchScript
|
||||
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
||||
yolov8n_openvino_model # OpenVINO
|
||||
yolov8n.engine # TensorRT
|
||||
yolov8n.mlmodel # CoreML (macOS-only)
|
||||
yolov8n_saved_model # TensorFlow SavedModel
|
||||
yolov8n.pb # TensorFlow GraphDef
|
||||
yolov8n.tflite # TensorFlow Lite
|
||||
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
"""
|
||||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data import load_inference_source
|
||||
from ultralytics.yolo.data.augment import LetterBox, classify_transforms
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, MACOS, SETTINGS, WINDOWS, callbacks, colorstr, ops
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||
|
||||
STREAM_WARNING = """
|
||||
WARNING ⚠️ stream/video/webcam/dir predict source will accumulate results in RAM unless `stream=True` is passed,
|
||||
causing potential out-of-memory errors for large sources or long-running streams/videos.
|
||||
|
||||
Usage:
|
||||
results = model(source=..., stream=True) # generator of Results objects
|
||||
for r in results:
|
||||
boxes = r.boxes # Boxes object for bbox outputs
|
||||
masks = r.masks # Masks object for segment masks outputs
|
||||
probs = r.probs # Class probabilities for classification outputs
|
||||
"""
|
||||
|
||||
|
||||
class BasePredictor:
|
||||
"""
|
||||
BasePredictor
|
||||
|
||||
A base class for creating predictors.
|
||||
|
||||
Attributes:
|
||||
args (SimpleNamespace): Configuration for the predictor.
|
||||
save_dir (Path): Directory to save results.
|
||||
done_warmup (bool): Whether the predictor has finished setup.
|
||||
model (nn.Module): Model used for prediction.
|
||||
data (dict): Data configuration.
|
||||
device (torch.device): Device used for prediction.
|
||||
dataset (Dataset): Dataset used for prediction.
|
||||
vid_path (str): Path to video file.
|
||||
vid_writer (cv2.VideoWriter): Video writer for saving video output.
|
||||
data_path (str): Path to data.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initializes the BasePredictor class.
|
||||
|
||||
Args:
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
self.save_dir = self.get_save_dir()
|
||||
if self.args.conf is None:
|
||||
self.args.conf = 0.25 # default conf=0.25
|
||||
self.done_warmup = False
|
||||
if self.args.show:
|
||||
self.args.show = check_imshow(warn=True)
|
||||
|
||||
# Usable if setup is done
|
||||
self.model = None
|
||||
self.data = self.args.data # data_dict
|
||||
self.imgsz = None
|
||||
self.device = None
|
||||
self.dataset = None
|
||||
self.vid_path, self.vid_writer = None, None
|
||||
self.plotted_img = None
|
||||
self.data_path = None
|
||||
self.source_type = None
|
||||
self.batch = None
|
||||
self.results = None
|
||||
self.transforms = None
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def get_save_dir(self):
|
||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||
name = self.args.name or f'{self.args.mode}'
|
||||
return increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
||||
|
||||
def preprocess(self, im):
|
||||
"""Prepares input image before inference.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
|
||||
"""
|
||||
not_tensor = not isinstance(im, torch.Tensor)
|
||||
if not_tensor:
|
||||
im = np.stack(self.pre_transform(im))
|
||||
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
im = torch.from_numpy(im)
|
||||
|
||||
img = im.to(self.device)
|
||||
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
||||
if not_tensor:
|
||||
img /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
return img
|
||||
|
||||
def inference(self, im, *args, **kwargs):
|
||||
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
|
||||
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
|
||||
return self.model(im, augment=self.args.augment, visualize=visualize)
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""Pre-transform input image before inference.
|
||||
|
||||
Args:
|
||||
im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
|
||||
|
||||
Return: A list of transformed imgs.
|
||||
"""
|
||||
same_shapes = all(x.shape == im[0].shape for x in im)
|
||||
auto = same_shapes and self.model.pt
|
||||
return [LetterBox(self.imgsz, auto=auto, stride=self.model.stride)(image=x) for x in im]
|
||||
|
||||
def write_results(self, idx, results, batch):
|
||||
"""Write inference results to a file or directory."""
|
||||
p, im, _ = batch
|
||||
log_string = ''
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
if self.source_type.webcam or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
|
||||
log_string += f'{idx}: '
|
||||
frame = self.dataset.count
|
||||
else:
|
||||
frame = getattr(self.dataset, 'frame', 0)
|
||||
self.data_path = p
|
||||
self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
|
||||
log_string += '%gx%g ' % im.shape[2:] # print string
|
||||
result = results[idx]
|
||||
log_string += result.verbose()
|
||||
|
||||
if self.args.save or self.args.show: # Add bbox to image
|
||||
plot_args = {
|
||||
'line_width': self.args.line_width,
|
||||
'boxes': self.args.boxes,
|
||||
'conf': self.args.show_conf,
|
||||
'labels': self.args.show_labels}
|
||||
if not self.args.retina_masks:
|
||||
plot_args['im_gpu'] = im[idx]
|
||||
self.plotted_img = result.plot(**plot_args)
|
||||
# Write
|
||||
if self.args.save_txt:
|
||||
result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf)
|
||||
if self.args.save_crop:
|
||||
result.save_crop(save_dir=self.save_dir / 'crops', file_name=self.data_path.stem)
|
||||
|
||||
return log_string
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Post-processes predictions for an image and returns them."""
|
||||
return preds
|
||||
|
||||
def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
|
||||
"""Performs inference on an image or stream."""
|
||||
self.stream = stream
|
||||
if stream:
|
||||
return self.stream_inference(source, model, *args, **kwargs)
|
||||
else:
|
||||
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
|
||||
|
||||
def predict_cli(self, source=None, model=None):
|
||||
"""Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode."""
|
||||
gen = self.stream_inference(source, model)
|
||||
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
|
||||
pass
|
||||
|
||||
def setup_source(self, source):
|
||||
"""Sets up source and inference mode."""
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
||||
self.transforms = getattr(self.model.model, 'transforms', classify_transforms(
|
||||
self.imgsz[0])) if self.args.task == 'classify' else None
|
||||
self.dataset = load_inference_source(source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride)
|
||||
self.source_type = self.dataset.source_type
|
||||
if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams
|
||||
len(self.dataset) > 1000 or # images
|
||||
any(getattr(self.dataset, 'video_flag', [False]))): # videos
|
||||
LOGGER.warning(STREAM_WARNING)
|
||||
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
|
||||
|
||||
@smart_inference_mode()
|
||||
def stream_inference(self, source=None, model=None, *args, **kwargs):
|
||||
"""Streams real-time inference on camera feed and saves results to file."""
|
||||
if self.args.verbose:
|
||||
LOGGER.info('')
|
||||
|
||||
# Setup model
|
||||
if not self.model:
|
||||
self.setup_model(model)
|
||||
|
||||
# Setup source every time predict is called
|
||||
self.setup_source(source if source is not None else self.args.source)
|
||||
|
||||
# Check if save_dir/ label file exists
|
||||
if self.args.save or self.args.save_txt:
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Warmup model
|
||||
if not self.done_warmup:
|
||||
self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
|
||||
self.done_warmup = True
|
||||
|
||||
self.seen, self.windows, self.batch, profilers = 0, [], None, (ops.Profile(), ops.Profile(), ops.Profile())
|
||||
self.run_callbacks('on_predict_start')
|
||||
for batch in self.dataset:
|
||||
self.run_callbacks('on_predict_batch_start')
|
||||
self.batch = batch
|
||||
path, im0s, vid_cap, s = batch
|
||||
|
||||
# Preprocess
|
||||
with profilers[0]:
|
||||
im = self.preprocess(im0s)
|
||||
|
||||
# Inference
|
||||
with profilers[1]:
|
||||
preds = self.inference(im, *args, **kwargs)
|
||||
|
||||
# Postprocess
|
||||
with profilers[2]:
|
||||
self.results = self.postprocess(preds, im, im0s)
|
||||
self.run_callbacks('on_predict_postprocess_end')
|
||||
|
||||
# Visualize, save, write results
|
||||
n = len(im0s)
|
||||
for i in range(n):
|
||||
self.seen += 1
|
||||
self.results[i].speed = {
|
||||
'preprocess': profilers[0].dt * 1E3 / n,
|
||||
'inference': profilers[1].dt * 1E3 / n,
|
||||
'postprocess': profilers[2].dt * 1E3 / n}
|
||||
p, im0 = path[i], None if self.source_type.tensor else im0s[i].copy()
|
||||
p = Path(p)
|
||||
|
||||
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||
s += self.write_results(i, self.results, (p, im, im0))
|
||||
if self.args.save or self.args.save_txt:
|
||||
self.results[i].save_dir = self.save_dir.__str__()
|
||||
if self.args.show and self.plotted_img is not None:
|
||||
self.show(p)
|
||||
if self.args.save and self.plotted_img is not None:
|
||||
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
||||
|
||||
self.run_callbacks('on_predict_batch_end')
|
||||
yield from self.results
|
||||
|
||||
# Print time (inference-only)
|
||||
if self.args.verbose:
|
||||
LOGGER.info(f'{s}{profilers[1].dt * 1E3:.1f}ms')
|
||||
|
||||
# Release assets
|
||||
if isinstance(self.vid_writer[-1], cv2.VideoWriter):
|
||||
self.vid_writer[-1].release() # release final video writer
|
||||
|
||||
# Print results
|
||||
if self.args.verbose and self.seen:
|
||||
t = tuple(x.t / self.seen * 1E3 for x in profilers) # speeds per image
|
||||
LOGGER.info(f'Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape '
|
||||
f'{(1, 3, *im.shape[2:])}' % t)
|
||||
if self.args.save or self.args.save_txt or self.args.save_crop:
|
||||
nl = len(list(self.save_dir.glob('labels/*.txt'))) # number of labels
|
||||
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
||||
|
||||
self.run_callbacks('on_predict_end')
|
||||
|
||||
def setup_model(self, model, verbose=True):
|
||||
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
|
||||
self.model = AutoBackend(model or self.args.model,
|
||||
device=select_device(self.args.device, verbose=verbose),
|
||||
dnn=self.args.dnn,
|
||||
data=self.args.data,
|
||||
fp16=self.args.half,
|
||||
fuse=True,
|
||||
verbose=verbose)
|
||||
|
||||
self.device = self.model.device # update device
|
||||
self.args.half = self.model.fp16 # update half
|
||||
self.model.eval()
|
||||
|
||||
def show(self, p):
|
||||
"""Display an image in a window using OpenCV imshow()."""
|
||||
im0 = self.plotted_img
|
||||
if platform.system() == 'Linux' and p not in self.windows:
|
||||
self.windows.append(p)
|
||||
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
|
||||
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
|
||||
cv2.imshow(str(p), im0)
|
||||
cv2.waitKey(500 if self.batch[3].startswith('image') else 1) # 1 millisecond
|
||||
|
||||
def save_preds(self, vid_cap, idx, save_path):
|
||||
"""Save video predictions as mp4 at specified path."""
|
||||
im0 = self.plotted_img
|
||||
# Save imgs
|
||||
if self.dataset.mode == 'image':
|
||||
cv2.imwrite(save_path, im0)
|
||||
else: # 'video' or 'stream'
|
||||
if self.vid_path[idx] != save_path: # new video
|
||||
self.vid_path[idx] = save_path
|
||||
if isinstance(self.vid_writer[idx], cv2.VideoWriter):
|
||||
self.vid_writer[idx].release() # release previous video writer
|
||||
if vid_cap: # video
|
||||
fps = int(vid_cap.get(cv2.CAP_PROP_FPS)) # integer required, floats produce error in MP4 codec
|
||||
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
else: # stream
|
||||
fps, w, h = 30, im0.shape[1], im0.shape[0]
|
||||
suffix = '.mp4' if MACOS else '.avi' if WINDOWS else '.avi'
|
||||
fourcc = 'avc1' if MACOS else 'WMV2' if WINDOWS else 'MJPG'
|
||||
save_path = str(Path(save_path).with_suffix(suffix))
|
||||
self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
|
||||
self.vid_writer[idx].write(im0)
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Runs all registered callbacks for a specific event."""
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
def add_callback(self, event: str, func):
|
||||
"""
|
||||
Add callback
|
||||
"""
|
||||
self.callbacks[event].append(func)
|
@ -1,614 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Ultralytics Results, Boxes and Masks classes for handling inference results
|
||||
|
||||
Usage: See https://docs.ultralytics.com/modes/predict/
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.data.augment import LetterBox
|
||||
from ultralytics.yolo.utils import LOGGER, SimpleClass, deprecation_warn, ops
|
||||
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
||||
|
||||
|
||||
class BaseTensor(SimpleClass):
|
||||
"""
|
||||
Base tensor class with additional methods for easy manipulation and device handling.
|
||||
"""
|
||||
|
||||
def __init__(self, data, orig_shape) -> None:
|
||||
"""Initialize BaseTensor with data and original shape.
|
||||
|
||||
Args:
|
||||
data (torch.Tensor | np.ndarray): Predictions, such as bboxes, masks and keypoints.
|
||||
orig_shape (tuple): Original shape of image.
|
||||
"""
|
||||
assert isinstance(data, (torch.Tensor, np.ndarray))
|
||||
self.data = data
|
||||
self.orig_shape = orig_shape
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""Return the shape of the data tensor."""
|
||||
return self.data.shape
|
||||
|
||||
def cpu(self):
|
||||
"""Return a copy of the tensor on CPU memory."""
|
||||
return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.cpu(), self.orig_shape)
|
||||
|
||||
def numpy(self):
|
||||
"""Return a copy of the tensor as a numpy array."""
|
||||
return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.numpy(), self.orig_shape)
|
||||
|
||||
def cuda(self):
|
||||
"""Return a copy of the tensor on GPU memory."""
|
||||
return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""Return a copy of the tensor with the specified device and dtype."""
|
||||
return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape)
|
||||
|
||||
def __len__(self): # override len(results)
|
||||
"""Return the length of the data tensor."""
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Return a BaseTensor with the specified index of the data tensor."""
|
||||
return self.__class__(self.data[idx], self.orig_shape)
|
||||
|
||||
|
||||
class Results(SimpleClass):
|
||||
"""
|
||||
A class for storing and manipulating inference results.
|
||||
|
||||
Args:
|
||||
orig_img (numpy.ndarray): The original image as a numpy array.
|
||||
path (str): The path to the image file.
|
||||
names (dict): A dictionary of class names.
|
||||
boxes (torch.tensor, optional): A 2D tensor of bounding box coordinates for each detection.
|
||||
masks (torch.tensor, optional): A 3D tensor of detection masks, where each mask is a binary image.
|
||||
probs (torch.tensor, optional): A 1D tensor of probabilities of each class for classification task.
|
||||
keypoints (List[List[float]], optional): A list of detected keypoints for each object.
|
||||
|
||||
|
||||
Attributes:
|
||||
orig_img (numpy.ndarray): The original image as a numpy array.
|
||||
orig_shape (tuple): The original image shape in (height, width) format.
|
||||
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
|
||||
masks (Masks, optional): A Masks object containing the detection masks.
|
||||
probs (Probs, optional): A Probs object containing probabilities of each class for classification task.
|
||||
names (dict): A dictionary of class names.
|
||||
path (str): The path to the image file.
|
||||
keypoints (Keypoints, optional): A Keypoints object containing detected keypoints for each object.
|
||||
speed (dict): A dictionary of preprocess, inference and postprocess speeds in milliseconds per image.
|
||||
_keys (tuple): A tuple of attribute names for non-empty attributes.
|
||||
"""
|
||||
|
||||
def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None) -> None:
|
||||
"""Initialize the Results class."""
|
||||
self.orig_img = orig_img
|
||||
self.orig_shape = orig_img.shape[:2]
|
||||
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
|
||||
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
|
||||
self.probs = Probs(probs) if probs is not None else None
|
||||
self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None
|
||||
self.speed = {'preprocess': None, 'inference': None, 'postprocess': None} # milliseconds per image
|
||||
self.names = names
|
||||
self.path = path
|
||||
self.save_dir = None
|
||||
self._keys = ('boxes', 'masks', 'probs', 'keypoints')
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Return a Results object for the specified index."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k)[idx])
|
||||
return r
|
||||
|
||||
def update(self, boxes=None, masks=None, probs=None):
|
||||
"""Update the boxes, masks, and probs attributes of the Results object."""
|
||||
if boxes is not None:
|
||||
self.boxes = Boxes(boxes, self.orig_shape)
|
||||
if masks is not None:
|
||||
self.masks = Masks(masks, self.orig_shape)
|
||||
if probs is not None:
|
||||
self.probs = probs
|
||||
|
||||
def cpu(self):
|
||||
"""Return a copy of the Results object with all tensors on CPU memory."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).cpu())
|
||||
return r
|
||||
|
||||
def numpy(self):
|
||||
"""Return a copy of the Results object with all tensors as numpy arrays."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).numpy())
|
||||
return r
|
||||
|
||||
def cuda(self):
|
||||
"""Return a copy of the Results object with all tensors on GPU memory."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).cuda())
|
||||
return r
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""Return a copy of the Results object with tensors on the specified device and dtype."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).to(*args, **kwargs))
|
||||
return r
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of detections in the Results object."""
|
||||
for k in self.keys:
|
||||
return len(getattr(self, k))
|
||||
|
||||
def new(self):
|
||||
"""Return a new Results object with the same image, path, and names."""
|
||||
return Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
"""Return a list of non-empty attribute names."""
|
||||
return [k for k in self._keys if getattr(self, k) is not None]
|
||||
|
||||
def plot(
|
||||
self,
|
||||
conf=True,
|
||||
line_width=None,
|
||||
font_size=None,
|
||||
font='Arial.ttf',
|
||||
pil=False,
|
||||
img=None,
|
||||
im_gpu=None,
|
||||
kpt_line=True,
|
||||
labels=True,
|
||||
boxes=True,
|
||||
masks=True,
|
||||
probs=True,
|
||||
**kwargs # deprecated args TODO: remove support in 8.2
|
||||
):
|
||||
"""
|
||||
Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
|
||||
|
||||
Args:
|
||||
conf (bool): Whether to plot the detection confidence score.
|
||||
line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
|
||||
font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
|
||||
font (str): The font to use for the text.
|
||||
pil (bool): Whether to return the image as a PIL Image.
|
||||
img (numpy.ndarray): Plot to another image. if not, plot to original image.
|
||||
im_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
|
||||
kpt_line (bool): Whether to draw lines connecting keypoints.
|
||||
labels (bool): Whether to plot the label of bounding boxes.
|
||||
boxes (bool): Whether to plot the bounding boxes.
|
||||
masks (bool): Whether to plot the masks.
|
||||
probs (bool): Whether to plot classification probability
|
||||
|
||||
Returns:
|
||||
(numpy.ndarray): A numpy array of the annotated image.
|
||||
"""
|
||||
if img is None and isinstance(self.orig_img, torch.Tensor):
|
||||
img = np.ascontiguousarray(self.orig_img[0].permute(1, 2, 0).cpu().detach().numpy()) * 255
|
||||
|
||||
# Deprecation warn TODO: remove in 8.2
|
||||
if 'show_conf' in kwargs:
|
||||
deprecation_warn('show_conf', 'conf')
|
||||
conf = kwargs['show_conf']
|
||||
assert type(conf) == bool, '`show_conf` should be of boolean type, i.e, show_conf=True/False'
|
||||
|
||||
if 'line_thickness' in kwargs:
|
||||
deprecation_warn('line_thickness', 'line_width')
|
||||
line_width = kwargs['line_thickness']
|
||||
assert type(line_width) == int, '`line_width` should be of int type, i.e, line_width=3'
|
||||
|
||||
names = self.names
|
||||
pred_boxes, show_boxes = self.boxes, boxes
|
||||
pred_masks, show_masks = self.masks, masks
|
||||
pred_probs, show_probs = self.probs, probs
|
||||
annotator = Annotator(
|
||||
deepcopy(self.orig_img if img is None else img),
|
||||
line_width,
|
||||
font_size,
|
||||
font,
|
||||
pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True
|
||||
example=names)
|
||||
|
||||
# Plot Segment results
|
||||
if pred_masks and show_masks:
|
||||
if im_gpu is None:
|
||||
img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
|
||||
im_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
|
||||
2, 0, 1).flip(0).contiguous() / 255
|
||||
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
|
||||
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)
|
||||
|
||||
# Plot Detect results
|
||||
if pred_boxes and show_boxes:
|
||||
for d in reversed(pred_boxes):
|
||||
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
|
||||
name = ('' if id is None else f'id:{id} ') + names[c]
|
||||
label = (f'{name} {conf:.2f}' if conf else name) if labels else None
|
||||
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||
|
||||
# Plot Classify results
|
||||
if pred_probs is not None and show_probs:
|
||||
text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)
|
||||
x = round(self.orig_shape[0] * 0.03)
|
||||
annotator.text([x, x], text, txt_color=(255, 255, 255)) # TODO: allow setting colors
|
||||
|
||||
# Plot Pose results
|
||||
if self.keypoints is not None:
|
||||
for k in reversed(self.keypoints.data):
|
||||
annotator.kpts(k, self.orig_shape, kpt_line=kpt_line)
|
||||
|
||||
return annotator.result()
|
||||
|
||||
def verbose(self):
|
||||
"""
|
||||
Return log string for each task.
|
||||
"""
|
||||
log_string = ''
|
||||
probs = self.probs
|
||||
boxes = self.boxes
|
||||
if len(self) == 0:
|
||||
return log_string if probs is not None else f'{log_string}(no detections), '
|
||||
if probs is not None:
|
||||
log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, "
|
||||
if boxes:
|
||||
for c in boxes.cls.unique():
|
||||
n = (boxes.cls == c).sum() # detections per class
|
||||
log_string += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "
|
||||
return log_string
|
||||
|
||||
def save_txt(self, txt_file, save_conf=False):
|
||||
"""
|
||||
Save predictions into txt file.
|
||||
|
||||
Args:
|
||||
txt_file (str): txt file path.
|
||||
save_conf (bool): save confidence score or not.
|
||||
"""
|
||||
boxes = self.boxes
|
||||
masks = self.masks
|
||||
probs = self.probs
|
||||
kpts = self.keypoints
|
||||
texts = []
|
||||
if probs is not None:
|
||||
# Classify
|
||||
[texts.append(f'{probs.data[j]:.2f} {self.names[j]}') for j in probs.top5]
|
||||
elif boxes:
|
||||
# Detect/segment/pose
|
||||
for j, d in enumerate(boxes):
|
||||
c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
|
||||
line = (c, *d.xywhn.view(-1))
|
||||
if masks:
|
||||
seg = masks[j].xyn[0].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2)
|
||||
line = (c, *seg)
|
||||
if kpts is not None:
|
||||
kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn
|
||||
line += (*kpt.reshape(-1).tolist(), )
|
||||
line += (conf, ) * save_conf + (() if id is None else (id, ))
|
||||
texts.append(('%g ' * len(line)).rstrip() % line)
|
||||
|
||||
if texts:
|
||||
with open(txt_file, 'a') as f:
|
||||
f.writelines(text + '\n' for text in texts)
|
||||
|
||||
def save_crop(self, save_dir, file_name=Path('im.jpg')):
|
||||
"""
|
||||
Save cropped predictions to `save_dir/cls/file_name.jpg`.
|
||||
|
||||
Args:
|
||||
save_dir (str | pathlib.Path): Save path.
|
||||
file_name (str | pathlib.Path): File name.
|
||||
"""
|
||||
if self.probs is not None:
|
||||
LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.')
|
||||
return
|
||||
if isinstance(save_dir, str):
|
||||
save_dir = Path(save_dir)
|
||||
if isinstance(file_name, str):
|
||||
file_name = Path(file_name)
|
||||
for d in self.boxes:
|
||||
save_one_box(d.xyxy,
|
||||
self.orig_img.copy(),
|
||||
file=save_dir / self.names[int(d.cls)] / f'{file_name.stem}.jpg',
|
||||
BGR=True)
|
||||
|
||||
def pandas(self):
|
||||
"""Convert the object to a pandas DataFrame (not yet implemented)."""
|
||||
LOGGER.warning("WARNING ⚠️ 'Results.pandas' method is not yet implemented.")
|
||||
|
||||
def tojson(self, normalize=False):
|
||||
"""Convert the object to JSON format."""
|
||||
if self.probs is not None:
|
||||
LOGGER.warning('Warning: Classify task do not support `tojson` yet.')
|
||||
return
|
||||
|
||||
import json
|
||||
|
||||
# Create list of detection dictionaries
|
||||
results = []
|
||||
data = self.boxes.data.cpu().tolist()
|
||||
h, w = self.orig_shape if normalize else (1, 1)
|
||||
for i, row in enumerate(data):
|
||||
box = {'x1': row[0] / w, 'y1': row[1] / h, 'x2': row[2] / w, 'y2': row[3] / h}
|
||||
conf = row[4]
|
||||
id = int(row[5])
|
||||
name = self.names[id]
|
||||
result = {'name': name, 'class': id, 'confidence': conf, 'box': box}
|
||||
if self.masks:
|
||||
x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1] # numpy array
|
||||
result['segments'] = {'x': (x / w).tolist(), 'y': (y / h).tolist()}
|
||||
if self.keypoints is not None:
|
||||
x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor
|
||||
result['keypoints'] = {'x': (x / w).tolist(), 'y': (y / h).tolist(), 'visible': visible.tolist()}
|
||||
results.append(result)
|
||||
|
||||
# Convert detections to JSON
|
||||
return json.dumps(results, indent=2)
|
||||
|
||||
|
||||
class Boxes(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating detection boxes.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes,
|
||||
with shape (num_boxes, 6). The last two columns should contain confidence and class values.
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Attributes:
|
||||
boxes (torch.Tensor | numpy.ndarray): The detection boxes with shape (num_boxes, 6).
|
||||
orig_shape (torch.Tensor | numpy.ndarray): Original image size, in the format (height, width).
|
||||
is_track (bool): True if the boxes also include track IDs, False otherwise.
|
||||
|
||||
Properties:
|
||||
xyxy (torch.Tensor | numpy.ndarray): The boxes in xyxy format.
|
||||
conf (torch.Tensor | numpy.ndarray): The confidence values of the boxes.
|
||||
cls (torch.Tensor | numpy.ndarray): The class values of the boxes.
|
||||
id (torch.Tensor | numpy.ndarray): The track IDs of the boxes (if available).
|
||||
xywh (torch.Tensor | numpy.ndarray): The boxes in xywh format.
|
||||
xyxyn (torch.Tensor | numpy.ndarray): The boxes in xyxy format normalized by original image size.
|
||||
xywhn (torch.Tensor | numpy.ndarray): The boxes in xywh format normalized by original image size.
|
||||
data (torch.Tensor): The raw bboxes tensor
|
||||
|
||||
Methods:
|
||||
cpu(): Move the object to CPU memory.
|
||||
numpy(): Convert the object to a numpy array.
|
||||
cuda(): Move the object to CUDA memory.
|
||||
to(*args, **kwargs): Move the object to the specified device.
|
||||
pandas(): Convert the object to a pandas DataFrame (not yet implemented).
|
||||
"""
|
||||
|
||||
def __init__(self, boxes, orig_shape) -> None:
|
||||
"""Initialize the Boxes class."""
|
||||
if boxes.ndim == 1:
|
||||
boxes = boxes[None, :]
|
||||
n = boxes.shape[-1]
|
||||
assert n in (6, 7), f'expected `n` in [6, 7], but got {n}' # xyxy, (track_id), conf, cls
|
||||
super().__init__(boxes, orig_shape)
|
||||
self.is_track = n == 7
|
||||
self.orig_shape = orig_shape
|
||||
|
||||
@property
|
||||
def xyxy(self):
|
||||
"""Return the boxes in xyxy format."""
|
||||
return self.data[:, :4]
|
||||
|
||||
@property
|
||||
def conf(self):
|
||||
"""Return the confidence values of the boxes."""
|
||||
return self.data[:, -2]
|
||||
|
||||
@property
|
||||
def cls(self):
|
||||
"""Return the class values of the boxes."""
|
||||
return self.data[:, -1]
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
"""Return the track IDs of the boxes (if available)."""
|
||||
return self.data[:, -3] if self.is_track else None
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=2) # maxsize 1 should suffice
|
||||
def xywh(self):
|
||||
"""Return the boxes in xywh format."""
|
||||
return ops.xyxy2xywh(self.xyxy)
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xyxyn(self):
|
||||
"""Return the boxes in xyxy format normalized by original image size."""
|
||||
xyxy = self.xyxy.clone() if isinstance(self.xyxy, torch.Tensor) else np.copy(self.xyxy)
|
||||
xyxy[..., [0, 2]] /= self.orig_shape[1]
|
||||
xyxy[..., [1, 3]] /= self.orig_shape[0]
|
||||
return xyxy
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xywhn(self):
|
||||
"""Return the boxes in xywh format normalized by original image size."""
|
||||
xywh = ops.xyxy2xywh(self.xyxy)
|
||||
xywh[..., [0, 2]] /= self.orig_shape[1]
|
||||
xywh[..., [1, 3]] /= self.orig_shape[0]
|
||||
return xywh
|
||||
|
||||
@property
|
||||
def boxes(self):
|
||||
"""Return the raw bboxes tensor (deprecated)."""
|
||||
LOGGER.warning("WARNING ⚠️ 'Boxes.boxes' is deprecated. Use 'Boxes.data' instead.")
|
||||
return self.data
|
||||
|
||||
|
||||
class Masks(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating detection masks.
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor | np.ndarray): A tensor containing the detection masks, with shape (num_masks, height, width).
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Attributes:
|
||||
masks (torch.Tensor | np.ndarray): A tensor containing the detection masks, with shape (num_masks, height, width).
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Properties:
|
||||
xy (list): A list of segments (pixels) which includes x, y segments of each detection.
|
||||
xyn (list): A list of segments (normalized) which includes x, y segments of each detection.
|
||||
|
||||
Methods:
|
||||
cpu(): Returns a copy of the masks tensor on CPU memory.
|
||||
numpy(): Returns a copy of the masks tensor as a numpy array.
|
||||
cuda(): Returns a copy of the masks tensor on GPU memory.
|
||||
to(): Returns a copy of the masks tensor with the specified device and dtype.
|
||||
"""
|
||||
|
||||
def __init__(self, masks, orig_shape) -> None:
|
||||
"""Initialize the Masks class."""
|
||||
if masks.ndim == 2:
|
||||
masks = masks[None, :]
|
||||
super().__init__(masks, orig_shape)
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def segments(self):
|
||||
"""Return segments (deprecated; normalized)."""
|
||||
LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and "
|
||||
"'Masks.xy' for segments (pixels) instead.")
|
||||
return self.xyn
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xyn(self):
|
||||
"""Return segments (normalized)."""
|
||||
return [
|
||||
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
|
||||
for x in ops.masks2segments(self.data)]
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xy(self):
|
||||
"""Return segments (pixels)."""
|
||||
return [
|
||||
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
|
||||
for x in ops.masks2segments(self.data)]
|
||||
|
||||
@property
|
||||
def masks(self):
|
||||
"""Return the raw masks tensor (deprecated)."""
|
||||
LOGGER.warning("WARNING ⚠️ 'Masks.masks' is deprecated. Use 'Masks.data' instead.")
|
||||
return self.data
|
||||
|
||||
def pandas(self):
|
||||
"""Convert the object to a pandas DataFrame (not yet implemented)."""
|
||||
LOGGER.warning("WARNING ⚠️ 'Masks.pandas' method is not yet implemented.")
|
||||
|
||||
|
||||
class Keypoints(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating detection keypoints.
|
||||
|
||||
Args:
|
||||
keypoints (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_dets, num_kpts, 2/3).
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Attributes:
|
||||
keypoints (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_dets, num_kpts, 2/3).
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Properties:
|
||||
xy (list): A list of keypoints (pixels) which includes x, y keypoints of each detection.
|
||||
xyn (list): A list of keypoints (normalized) which includes x, y keypoints of each detection.
|
||||
|
||||
Methods:
|
||||
cpu(): Returns a copy of the keypoints tensor on CPU memory.
|
||||
numpy(): Returns a copy of the keypoints tensor as a numpy array.
|
||||
cuda(): Returns a copy of the keypoints tensor on GPU memory.
|
||||
to(): Returns a copy of the keypoints tensor with the specified device and dtype.
|
||||
"""
|
||||
|
||||
def __init__(self, keypoints, orig_shape) -> None:
|
||||
if keypoints.ndim == 2:
|
||||
keypoints = keypoints[None, :]
|
||||
super().__init__(keypoints, orig_shape)
|
||||
self.has_visible = self.data.shape[-1] == 3
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xy(self):
|
||||
return self.data[..., :2]
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xyn(self):
|
||||
xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy)
|
||||
xy[..., 0] /= self.orig_shape[1]
|
||||
xy[..., 1] /= self.orig_shape[0]
|
||||
return xy
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def conf(self):
|
||||
return self.data[..., 2] if self.has_visible else None
|
||||
|
||||
|
||||
class Probs(BaseTensor):
|
||||
"""
|
||||
A class for storing and manipulating classify predictions.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_class, ).
|
||||
|
||||
Attributes:
|
||||
probs (torch.Tensor | np.ndarray): A tensor containing the detection keypoints, with shape (num_class).
|
||||
|
||||
Properties:
|
||||
top5 (list[int]): Top 1 indice.
|
||||
top1 (int): Top 5 indices.
|
||||
|
||||
Methods:
|
||||
cpu(): Returns a copy of the probs tensor on CPU memory.
|
||||
numpy(): Returns a copy of the probs tensor as a numpy array.
|
||||
cuda(): Returns a copy of the probs tensor on GPU memory.
|
||||
to(): Returns a copy of the probs tensor with the specified device and dtype.
|
||||
"""
|
||||
|
||||
def __init__(self, probs, orig_shape=None) -> None:
|
||||
super().__init__(probs, orig_shape)
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def top5(self):
|
||||
"""Return the indices of top 5."""
|
||||
return (-self.data).argsort(0)[:5].tolist() # this way works with both torch and numpy.
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def top1(self):
|
||||
"""Return the indices of top 1."""
|
||||
return int(self.data.argmax())
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def top5conf(self):
|
||||
"""Return the confidences of top 5."""
|
||||
return self.data[self.top5]
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def top1conf(self):
|
||||
"""Return the confidences of top 1."""
|
||||
return self.data[self.top1]
|
@ -1,664 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Train a model on a dataset
|
||||
|
||||
Usage:
|
||||
$ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
from torch import nn, optim
|
||||
from torch.cuda import amp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks,
|
||||
clean_url, colorstr, emojis, yaml_save)
|
||||
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
||||
from ultralytics.yolo.utils.checks import check_amp, check_file, check_imgsz, print_args
|
||||
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
||||
from ultralytics.yolo.utils.files import get_latest_run, increment_path
|
||||
from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle,
|
||||
select_device, strip_optimizer)
|
||||
|
||||
|
||||
class BaseTrainer:
|
||||
"""
|
||||
BaseTrainer
|
||||
|
||||
A base class for creating trainers.
|
||||
|
||||
Attributes:
|
||||
args (SimpleNamespace): Configuration for the trainer.
|
||||
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
|
||||
validator (BaseValidator): Validator instance.
|
||||
model (nn.Module): Model instance.
|
||||
callbacks (defaultdict): Dictionary of callbacks.
|
||||
save_dir (Path): Directory to save results.
|
||||
wdir (Path): Directory to save weights.
|
||||
last (Path): Path to last checkpoint.
|
||||
best (Path): Path to best checkpoint.
|
||||
save_period (int): Save checkpoint every x epochs (disabled if < 1).
|
||||
batch_size (int): Batch size for training.
|
||||
epochs (int): Number of epochs to train for.
|
||||
start_epoch (int): Starting epoch for training.
|
||||
device (torch.device): Device to use for training.
|
||||
amp (bool): Flag to enable AMP (Automatic Mixed Precision).
|
||||
scaler (amp.GradScaler): Gradient scaler for AMP.
|
||||
data (str): Path to data.
|
||||
trainset (torch.utils.data.Dataset): Training dataset.
|
||||
testset (torch.utils.data.Dataset): Testing dataset.
|
||||
ema (nn.Module): EMA (Exponential Moving Average) of the model.
|
||||
lf (nn.Module): Loss function.
|
||||
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
|
||||
best_fitness (float): The best fitness value achieved.
|
||||
fitness (float): Current fitness value.
|
||||
loss (float): Current loss value.
|
||||
tloss (float): Total loss value.
|
||||
loss_names (list): List of loss names.
|
||||
csv (Path): Path to results CSV file.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initializes the BaseTrainer class.
|
||||
|
||||
Args:
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
self.device = select_device(self.args.device, self.args.batch)
|
||||
self.check_resume()
|
||||
self.validator = None
|
||||
self.model = None
|
||||
self.metrics = None
|
||||
self.plots = {}
|
||||
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
||||
|
||||
# Dirs
|
||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||
name = self.args.name or f'{self.args.mode}'
|
||||
if hasattr(self.args, 'save_dir'):
|
||||
self.save_dir = Path(self.args.save_dir)
|
||||
else:
|
||||
self.save_dir = Path(
|
||||
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
|
||||
self.wdir = self.save_dir / 'weights' # weights dir
|
||||
if RANK in (-1, 0):
|
||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||
self.args.save_dir = str(self.save_dir)
|
||||
yaml_save(self.save_dir / 'args.yaml', vars(self.args)) # save run args
|
||||
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
|
||||
self.save_period = self.args.save_period
|
||||
|
||||
self.batch_size = self.args.batch
|
||||
self.epochs = self.args.epochs
|
||||
self.start_epoch = 0
|
||||
if RANK == -1:
|
||||
print_args(vars(self.args))
|
||||
|
||||
# Device
|
||||
if self.device.type == 'cpu':
|
||||
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
||||
|
||||
# Model and Dataset
|
||||
self.model = self.args.model
|
||||
try:
|
||||
if self.args.task == 'classify':
|
||||
self.data = check_cls_dataset(self.args.data)
|
||||
elif self.args.data.endswith('.yaml') or self.args.task in ('detect', 'segment'):
|
||||
self.data = check_det_dataset(self.args.data)
|
||||
if 'yaml_file' in self.data:
|
||||
self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage
|
||||
except Exception as e:
|
||||
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
||||
|
||||
self.trainset, self.testset = self.get_dataset(self.data)
|
||||
self.ema = None
|
||||
|
||||
# Optimization utils init
|
||||
self.lf = None
|
||||
self.scheduler = None
|
||||
|
||||
# Epoch level metrics
|
||||
self.best_fitness = None
|
||||
self.fitness = None
|
||||
self.loss = None
|
||||
self.tloss = None
|
||||
self.loss_names = ['Loss']
|
||||
self.csv = self.save_dir / 'results.csv'
|
||||
self.plot_idx = [0, 1, 2]
|
||||
|
||||
# Callbacks
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
if RANK in (-1, 0):
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
Appends the given callback.
|
||||
"""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, event: str, callback):
|
||||
"""
|
||||
Overrides the existing callbacks with the given callback.
|
||||
"""
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Run all existing callbacks associated with a particular event."""
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
def train(self):
|
||||
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
|
||||
if isinstance(self.args.device, int) or self.args.device: # i.e. device=0 or device=[0,1,2,3]
|
||||
world_size = torch.cuda.device_count()
|
||||
elif torch.cuda.is_available(): # i.e. device=None or device=''
|
||||
world_size = 1 # default to device 0
|
||||
else: # i.e. device='cpu' or 'mps'
|
||||
world_size = 0
|
||||
|
||||
# Run subprocess if DDP training, else train normally
|
||||
if world_size > 1 and 'LOCAL_RANK' not in os.environ:
|
||||
# Argument checks
|
||||
if self.args.rect:
|
||||
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting rect=False")
|
||||
self.args.rect = False
|
||||
# Command
|
||||
cmd, file = generate_ddp_command(world_size, self)
|
||||
try:
|
||||
LOGGER.info(f'DDP command: {cmd}')
|
||||
subprocess.run(cmd, check=True)
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
ddp_cleanup(self, str(file))
|
||||
else:
|
||||
self._do_train(world_size)
|
||||
|
||||
def _setup_ddp(self, world_size):
|
||||
"""Initializes and sets the DistributedDataParallel parameters for training."""
|
||||
torch.cuda.set_device(RANK)
|
||||
self.device = torch.device('cuda', RANK)
|
||||
LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
||||
os.environ['NCCL_BLOCKING_WAIT'] = '1' # set to enforce timeout
|
||||
dist.init_process_group(
|
||||
'nccl' if dist.is_nccl_available() else 'gloo',
|
||||
timeout=timedelta(seconds=10800), # 3 hours
|
||||
rank=RANK,
|
||||
world_size=world_size)
|
||||
|
||||
def _setup_train(self, world_size):
|
||||
"""
|
||||
Builds dataloaders and optimizer on correct rank process.
|
||||
"""
|
||||
# Model
|
||||
self.run_callbacks('on_pretrain_routine_start')
|
||||
ckpt = self.setup_model()
|
||||
self.model = self.model.to(self.device)
|
||||
self.set_model_attributes()
|
||||
# Check AMP
|
||||
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
|
||||
if self.amp and RANK in (-1, 0): # Single-GPU and DDP
|
||||
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
||||
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
||||
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
||||
if RANK > -1 and world_size > 1: # DDP
|
||||
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
|
||||
self.amp = bool(self.amp) # as boolean
|
||||
self.scaler = amp.GradScaler(enabled=self.amp)
|
||||
if world_size > 1:
|
||||
self.model = DDP(self.model, device_ids=[RANK])
|
||||
# Check imgsz
|
||||
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
|
||||
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
|
||||
# Batch size
|
||||
if self.batch_size == -1:
|
||||
if RANK == -1: # single-GPU only, estimate best batch size
|
||||
self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
|
||||
else:
|
||||
SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
|
||||
'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')
|
||||
|
||||
# Dataloaders
|
||||
batch_size = self.batch_size // max(world_size, 1)
|
||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
|
||||
if RANK in (-1, 0):
|
||||
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
|
||||
self.validator = self.get_validator()
|
||||
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
|
||||
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
||||
self.ema = ModelEMA(self.model)
|
||||
if self.args.plots:
|
||||
self.plot_training_labels()
|
||||
|
||||
# Optimizer
|
||||
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
|
||||
weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
||||
iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
|
||||
self.optimizer = self.build_optimizer(model=self.model,
|
||||
name=self.args.optimizer,
|
||||
lr=self.args.lr0,
|
||||
momentum=self.args.momentum,
|
||||
decay=weight_decay,
|
||||
iterations=iterations)
|
||||
# Scheduler
|
||||
if self.args.cos_lr:
|
||||
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
|
||||
else:
|
||||
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
||||
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
|
||||
self.resume_training(ckpt)
|
||||
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
||||
self.run_callbacks('on_pretrain_routine_end')
|
||||
|
||||
def _do_train(self, world_size=1):
|
||||
"""Train completed, evaluate and plot if specified by arguments."""
|
||||
if world_size > 1:
|
||||
self._setup_ddp(world_size)
|
||||
|
||||
self._setup_train(world_size)
|
||||
|
||||
self.epoch_time = None
|
||||
self.epoch_time_start = time.time()
|
||||
self.train_time_start = time.time()
|
||||
nb = len(self.train_loader) # number of batches
|
||||
nw = max(round(self.args.warmup_epochs *
|
||||
nb), 100) if self.args.warmup_epochs > 0 else -1 # number of warmup iterations
|
||||
last_opt_step = -1
|
||||
self.run_callbacks('on_train_start')
|
||||
LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
|
||||
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
||||
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
||||
f'Starting training for {self.epochs} epochs...')
|
||||
if self.args.close_mosaic:
|
||||
base_idx = (self.epochs - self.args.close_mosaic) * nb
|
||||
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
|
||||
epoch = self.epochs # predefine for resume fully trained model edge cases
|
||||
for epoch in range(self.start_epoch, self.epochs):
|
||||
self.epoch = epoch
|
||||
self.run_callbacks('on_train_epoch_start')
|
||||
self.model.train()
|
||||
if RANK != -1:
|
||||
self.train_loader.sampler.set_epoch(epoch)
|
||||
pbar = enumerate(self.train_loader)
|
||||
# Update dataloader attributes (optional)
|
||||
if epoch == (self.epochs - self.args.close_mosaic):
|
||||
LOGGER.info('Closing dataloader mosaic')
|
||||
if hasattr(self.train_loader.dataset, 'mosaic'):
|
||||
self.train_loader.dataset.mosaic = False
|
||||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||
self.train_loader.reset()
|
||||
|
||||
if RANK in (-1, 0):
|
||||
LOGGER.info(self.progress_string())
|
||||
pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
|
||||
self.tloss = None
|
||||
self.optimizer.zero_grad()
|
||||
for i, batch in pbar:
|
||||
self.run_callbacks('on_train_batch_start')
|
||||
# Warmup
|
||||
ni = i + nb * epoch
|
||||
if ni <= nw:
|
||||
xi = [0, nw] # x interp
|
||||
self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
|
||||
for j, x in enumerate(self.optimizer.param_groups):
|
||||
# Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
|
||||
x['lr'] = np.interp(
|
||||
ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
|
||||
if 'momentum' in x:
|
||||
x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
|
||||
|
||||
# Forward
|
||||
with torch.cuda.amp.autocast(self.amp):
|
||||
batch = self.preprocess_batch(batch)
|
||||
self.loss, self.loss_items = self.model(batch)
|
||||
if RANK != -1:
|
||||
self.loss *= world_size
|
||||
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
|
||||
else self.loss_items
|
||||
|
||||
# Backward
|
||||
self.scaler.scale(self.loss).backward()
|
||||
|
||||
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
|
||||
if ni - last_opt_step >= self.accumulate:
|
||||
self.optimizer_step()
|
||||
last_opt_step = ni
|
||||
|
||||
# Log
|
||||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
|
||||
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
|
||||
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
||||
if RANK in (-1, 0):
|
||||
pbar.set_description(
|
||||
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
|
||||
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch['cls'].shape[0], batch['img'].shape[-1]))
|
||||
self.run_callbacks('on_batch_end')
|
||||
if self.args.plots and ni in self.plot_idx:
|
||||
self.plot_training_samples(batch, ni)
|
||||
|
||||
self.run_callbacks('on_train_batch_end')
|
||||
|
||||
self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||
|
||||
self.scheduler.step()
|
||||
self.run_callbacks('on_train_epoch_end')
|
||||
|
||||
if RANK in (-1, 0):
|
||||
|
||||
# Validation
|
||||
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
||||
final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop
|
||||
|
||||
if self.args.val or final_epoch:
|
||||
self.metrics, self.fitness = self.validate()
|
||||
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
|
||||
self.stop = self.stopper(epoch + 1, self.fitness)
|
||||
|
||||
# Save model
|
||||
if self.args.save or (epoch + 1 == self.epochs):
|
||||
self.save_model()
|
||||
self.run_callbacks('on_model_save')
|
||||
|
||||
tnow = time.time()
|
||||
self.epoch_time = tnow - self.epoch_time_start
|
||||
self.epoch_time_start = tnow
|
||||
self.run_callbacks('on_fit_epoch_end')
|
||||
torch.cuda.empty_cache() # clears GPU vRAM at end of epoch, can help with out of memory errors
|
||||
|
||||
# Early Stopping
|
||||
if RANK != -1: # if DDP training
|
||||
broadcast_list = [self.stop if RANK == 0 else None]
|
||||
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
|
||||
if RANK != 0:
|
||||
self.stop = broadcast_list[0]
|
||||
if self.stop:
|
||||
break # must break all DDP ranks
|
||||
|
||||
if RANK in (-1, 0):
|
||||
# Do final val with best.pt
|
||||
LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
|
||||
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
|
||||
self.final_eval()
|
||||
if self.args.plots:
|
||||
self.plot_metrics()
|
||||
self.run_callbacks('on_train_end')
|
||||
torch.cuda.empty_cache()
|
||||
self.run_callbacks('teardown')
|
||||
|
||||
def save_model(self):
|
||||
"""Save model checkpoints based on various conditions."""
|
||||
ckpt = {
|
||||
'epoch': self.epoch,
|
||||
'best_fitness': self.best_fitness,
|
||||
'model': deepcopy(de_parallel(self.model)).half(),
|
||||
'ema': deepcopy(self.ema.ema).half(),
|
||||
'updates': self.ema.updates,
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'train_args': vars(self.args), # save as dict
|
||||
'date': datetime.now().isoformat(),
|
||||
'version': __version__}
|
||||
|
||||
# Use dill (if exists) to serialize the lambda functions where pickle does not do this
|
||||
try:
|
||||
import dill as pickle
|
||||
except ImportError:
|
||||
import pickle
|
||||
|
||||
# Save last, best and delete
|
||||
torch.save(ckpt, self.last, pickle_module=pickle)
|
||||
if self.best_fitness == self.fitness:
|
||||
torch.save(ckpt, self.best, pickle_module=pickle)
|
||||
if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
|
||||
torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
|
||||
del ckpt
|
||||
|
||||
@staticmethod
|
||||
def get_dataset(data):
|
||||
"""
|
||||
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
||||
"""
|
||||
return data['train'], data.get('val') or data.get('test')
|
||||
|
||||
def setup_model(self):
|
||||
"""
|
||||
load/create/download model for any task.
|
||||
"""
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
||||
model, weights = self.model, None
|
||||
ckpt = None
|
||||
if str(model).endswith('.pt'):
|
||||
weights, ckpt = attempt_load_one_weight(model)
|
||||
cfg = ckpt['model'].yaml
|
||||
else:
|
||||
cfg = model
|
||||
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
|
||||
return ckpt
|
||||
|
||||
def optimizer_step(self):
|
||||
"""Perform a single step of the training optimizer with gradient clipping and EMA update."""
|
||||
self.scaler.unscale_(self.optimizer) # unscale gradients
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
self.optimizer.zero_grad()
|
||||
if self.ema:
|
||||
self.ema.update(self.model)
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""
|
||||
Allows custom preprocessing model inputs and ground truths depending on task type.
|
||||
"""
|
||||
return batch
|
||||
|
||||
def validate(self):
|
||||
"""
|
||||
Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
||||
"""
|
||||
metrics = self.validator(self)
|
||||
fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||
if not self.best_fitness or self.best_fitness < fitness:
|
||||
self.best_fitness = fitness
|
||||
return metrics, fitness
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Get model and raise NotImplementedError for loading cfg files."""
|
||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||
|
||||
def get_validator(self):
|
||||
"""Returns a NotImplementedError when the get_validator function is called."""
|
||||
raise NotImplementedError('get_validator function not implemented in trainer')
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
"""
|
||||
Returns dataloader derived from torch.data.Dataloader.
|
||||
"""
|
||||
raise NotImplementedError('get_dataloader function not implemented in trainer')
|
||||
|
||||
def build_dataset(self, img_path, mode='train', batch=None):
|
||||
"""Build dataset"""
|
||||
raise NotImplementedError('build_dataset function not implemented in trainer')
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||
"""
|
||||
Returns a loss dict with labelled training loss items tensor
|
||||
"""
|
||||
# Not needed for classification but necessary for segmentation & detection
|
||||
return {'loss': loss_items} if loss_items is not None else ['loss']
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""
|
||||
To set or update model parameters before training.
|
||||
"""
|
||||
self.model.names = self.data['names']
|
||||
|
||||
def build_targets(self, preds, targets):
|
||||
"""Builds target tensors for training YOLO model."""
|
||||
pass
|
||||
|
||||
def progress_string(self):
|
||||
"""Returns a string describing training progress."""
|
||||
return ''
|
||||
|
||||
# TODO: may need to put these following functions into callback
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plots training samples during YOLOv5 training."""
|
||||
pass
|
||||
|
||||
def plot_training_labels(self):
|
||||
"""Plots training labels for YOLO model."""
|
||||
pass
|
||||
|
||||
def save_metrics(self, metrics):
|
||||
"""Saves training metrics to a CSV file."""
|
||||
keys, vals = list(metrics.keys()), list(metrics.values())
|
||||
n = len(metrics) + 1 # number of cols
|
||||
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
|
||||
with open(self.csv, 'a') as f:
|
||||
f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plot and display metrics visually."""
|
||||
pass
|
||||
|
||||
def on_plot(self, name, data=None):
|
||||
"""Registers plots (e.g. to be consumed in callbacks)"""
|
||||
self.plots[name] = {'data': data, 'timestamp': time.time()}
|
||||
|
||||
def final_eval(self):
|
||||
"""Performs final evaluation and validation for object detection YOLO model."""
|
||||
for f in self.last, self.best:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
if f is self.best:
|
||||
LOGGER.info(f'\nValidating {f}...')
|
||||
self.metrics = self.validator(model=f)
|
||||
self.metrics.pop('fitness', None)
|
||||
self.run_callbacks('on_fit_epoch_end')
|
||||
|
||||
def check_resume(self):
|
||||
"""Check if resume checkpoint exists and update arguments accordingly."""
|
||||
resume = self.args.resume
|
||||
if resume:
|
||||
try:
|
||||
exists = isinstance(resume, (str, Path)) and Path(resume).exists()
|
||||
last = Path(check_file(resume) if exists else get_latest_run())
|
||||
|
||||
# Check that resume data YAML exists, otherwise strip to force re-download of dataset
|
||||
ckpt_args = attempt_load_weights(last).args
|
||||
if not Path(ckpt_args['data']).exists():
|
||||
ckpt_args['data'] = self.args.data
|
||||
|
||||
self.args = get_cfg(ckpt_args)
|
||||
self.args.model, resume = str(last), True # reinstate
|
||||
except Exception as e:
|
||||
raise FileNotFoundError('Resume checkpoint not found. Please pass a valid checkpoint to resume from, '
|
||||
"i.e. 'yolo train resume model=path/to/last.pt'") from e
|
||||
self.resume = resume
|
||||
|
||||
def resume_training(self, ckpt):
|
||||
"""Resume YOLO training from given epoch and best fitness."""
|
||||
if ckpt is None:
|
||||
return
|
||||
best_fitness = 0.0
|
||||
start_epoch = ckpt['epoch'] + 1
|
||||
if ckpt['optimizer'] is not None:
|
||||
self.optimizer.load_state_dict(ckpt['optimizer']) # optimizer
|
||||
best_fitness = ckpt['best_fitness']
|
||||
if self.ema and ckpt.get('ema'):
|
||||
self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
|
||||
self.ema.updates = ckpt['updates']
|
||||
if self.resume:
|
||||
assert start_epoch > 0, \
|
||||
f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
|
||||
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
|
||||
LOGGER.info(
|
||||
f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
|
||||
if self.epochs < start_epoch:
|
||||
LOGGER.info(
|
||||
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
|
||||
self.epochs += ckpt['epoch'] # finetune additional epochs
|
||||
self.best_fitness = best_fitness
|
||||
self.start_epoch = start_epoch
|
||||
if start_epoch > (self.epochs - self.args.close_mosaic):
|
||||
LOGGER.info('Closing dataloader mosaic')
|
||||
if hasattr(self.train_loader.dataset, 'mosaic'):
|
||||
self.train_loader.dataset.mosaic = False
|
||||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||
|
||||
def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
|
||||
"""
|
||||
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate,
|
||||
momentum, weight decay, and number of iterations.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model for which to build an optimizer.
|
||||
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
|
||||
based on the number of iterations. Default: 'auto'.
|
||||
lr (float, optional): The learning rate for the optimizer. Default: 0.001.
|
||||
momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
|
||||
decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
|
||||
iterations (float, optional): The number of iterations, which determines the optimizer if
|
||||
name is 'auto'. Default: 1e5.
|
||||
|
||||
Returns:
|
||||
(torch.optim.Optimizer): The constructed optimizer.
|
||||
"""
|
||||
|
||||
g = [], [], [] # optimizer parameter groups
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
||||
if name == 'auto':
|
||||
nc = getattr(model, 'nc', 10) # number of classes
|
||||
lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
|
||||
name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
|
||||
self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
|
||||
|
||||
for module_name, module in model.named_modules():
|
||||
for param_name, param in module.named_parameters(recurse=False):
|
||||
fullname = f'{module_name}.{param_name}' if module_name else param_name
|
||||
if 'bias' in fullname: # bias (no decay)
|
||||
g[2].append(param)
|
||||
elif isinstance(module, bn): # weight (no decay)
|
||||
g[1].append(param)
|
||||
else: # weight (with decay)
|
||||
g[0].append(param)
|
||||
|
||||
if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
|
||||
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
|
||||
elif name == 'RMSProp':
|
||||
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
||||
elif name == 'SGD':
|
||||
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Optimizer '{name}' not found in list of available optimizers "
|
||||
f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].'
|
||||
'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.')
|
||||
|
||||
optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
|
||||
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
|
||||
LOGGER.info(
|
||||
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
|
||||
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
|
||||
return optimizer
|
@ -1,278 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Check a model's accuracy on a test or val split of a dataset
|
||||
|
||||
Usage:
|
||||
$ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640
|
||||
|
||||
Usage - formats:
|
||||
$ yolo mode=val model=yolov8n.pt # PyTorch
|
||||
yolov8n.torchscript # TorchScript
|
||||
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
||||
yolov8n_openvino_model # OpenVINO
|
||||
yolov8n.engine # TensorRT
|
||||
yolov8n.mlmodel # CoreML (macOS-only)
|
||||
yolov8n_saved_model # TensorFlow SavedModel
|
||||
yolov8n.pb # TensorFlow GraphDef
|
||||
yolov8n.tflite # TensorFlow Lite
|
||||
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
|
||||
from ultralytics.yolo.utils.checks import check_imgsz
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device, smart_inference_mode
|
||||
|
||||
|
||||
class BaseValidator:
|
||||
"""
|
||||
BaseValidator
|
||||
|
||||
A base class for creating validators.
|
||||
|
||||
Attributes:
|
||||
dataloader (DataLoader): Dataloader to use for validation.
|
||||
pbar (tqdm): Progress bar to update during validation.
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
model (nn.Module): Model to validate.
|
||||
data (dict): Data dictionary.
|
||||
device (torch.device): Device to use for validation.
|
||||
batch_i (int): Current batch index.
|
||||
training (bool): Whether the model is in training mode.
|
||||
speed (float): Batch processing speed in seconds.
|
||||
jdict (dict): Dictionary to store validation results.
|
||||
save_dir (Path): Directory to save results.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
"""
|
||||
Initializes a BaseValidator instance.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
|
||||
save_dir (Path): Directory to save results.
|
||||
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
"""
|
||||
self.dataloader = dataloader
|
||||
self.pbar = pbar
|
||||
self.args = args or get_cfg(DEFAULT_CFG)
|
||||
self.model = None
|
||||
self.data = None
|
||||
self.device = None
|
||||
self.batch_i = None
|
||||
self.training = True
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
self.jdict = None
|
||||
|
||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||
name = self.args.name or f'{self.args.mode}'
|
||||
self.save_dir = save_dir or increment_path(Path(project) / name,
|
||||
exist_ok=self.args.exist_ok if RANK in (-1, 0) else True)
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self.args.conf is None:
|
||||
self.args.conf = 0.001 # default conf=0.001
|
||||
|
||||
self.plots = {}
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, trainer=None, model=None):
|
||||
"""
|
||||
Supports validation of a pre-trained model if passed or a model being trained
|
||||
if trainer is passed (trainer gets priority).
|
||||
"""
|
||||
self.training = trainer is not None
|
||||
if self.training:
|
||||
self.device = trainer.device
|
||||
self.data = trainer.data
|
||||
model = trainer.ema.ema or trainer.model
|
||||
self.args.half = self.device.type != 'cpu' # force FP16 val during training
|
||||
model = model.half() if self.args.half else model.float()
|
||||
self.model = model
|
||||
self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
||||
self.args.plots = trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
|
||||
model.eval()
|
||||
else:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
self.run_callbacks('on_val_start')
|
||||
assert model is not None, 'Either trainer or model is needed for validation'
|
||||
model = AutoBackend(model,
|
||||
device=select_device(self.args.device, self.args.batch),
|
||||
dnn=self.args.dnn,
|
||||
data=self.args.data,
|
||||
fp16=self.args.half)
|
||||
self.model = model
|
||||
self.device = model.device # update device
|
||||
self.args.half = model.fp16 # update half
|
||||
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
|
||||
imgsz = check_imgsz(self.args.imgsz, stride=stride)
|
||||
if engine:
|
||||
self.args.batch = model.batch_size
|
||||
elif not pt and not jit:
|
||||
self.args.batch = 1 # export.py models default to batch-size 1
|
||||
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||
|
||||
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
|
||||
self.data = check_det_dataset(self.args.data)
|
||||
elif self.args.task == 'classify':
|
||||
self.data = check_cls_dataset(self.args.data, split=self.args.split)
|
||||
else:
|
||||
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
|
||||
|
||||
if self.device.type == 'cpu':
|
||||
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
||||
if not pt:
|
||||
self.args.rect = False
|
||||
self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
|
||||
|
||||
model.eval()
|
||||
model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup
|
||||
|
||||
dt = Profile(), Profile(), Profile(), Profile()
|
||||
n_batches = len(self.dataloader)
|
||||
desc = self.get_desc()
|
||||
# NOTE: keeping `not self.training` in tqdm will eliminate pbar after segmentation evaluation during training,
|
||||
# which may affect classification task since this arg is in yolov5/classify/val.py.
|
||||
# bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
|
||||
bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
|
||||
self.init_metrics(de_parallel(model))
|
||||
self.jdict = [] # empty before each val
|
||||
for batch_i, batch in enumerate(bar):
|
||||
self.run_callbacks('on_val_batch_start')
|
||||
self.batch_i = batch_i
|
||||
# Preprocess
|
||||
with dt[0]:
|
||||
batch = self.preprocess(batch)
|
||||
|
||||
# Inference
|
||||
with dt[1]:
|
||||
preds = model(batch['img'], augment=self.args.augment)
|
||||
|
||||
# Loss
|
||||
with dt[2]:
|
||||
if self.training:
|
||||
self.loss += model.loss(batch, preds)[1]
|
||||
|
||||
# Postprocess
|
||||
with dt[3]:
|
||||
preds = self.postprocess(preds)
|
||||
|
||||
self.update_metrics(preds, batch)
|
||||
if self.args.plots and batch_i < 3:
|
||||
self.plot_val_samples(batch, batch_i)
|
||||
self.plot_predictions(batch, preds, batch_i)
|
||||
|
||||
self.run_callbacks('on_val_batch_end')
|
||||
stats = self.get_stats()
|
||||
self.check_stats(stats)
|
||||
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
|
||||
self.finalize_metrics()
|
||||
self.print_results()
|
||||
self.run_callbacks('on_val_end')
|
||||
if self.training:
|
||||
model.float()
|
||||
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
|
||||
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
|
||||
else:
|
||||
LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
|
||||
tuple(self.speed.values()))
|
||||
if self.args.save_json and self.jdict:
|
||||
with open(str(self.save_dir / 'predictions.json'), 'w') as f:
|
||||
LOGGER.info(f'Saving {f.name}...')
|
||||
json.dump(self.jdict, f) # flatten and save
|
||||
stats = self.eval_json(stats) # update stats
|
||||
if self.args.plots or self.args.save_json:
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||
return stats
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""Appends the given callback."""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Runs all callbacks associated with a specified event."""
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
"""Get data loader from dataset path and batch size."""
|
||||
raise NotImplementedError('get_dataloader function not implemented for this validator')
|
||||
|
||||
def build_dataset(self, img_path):
|
||||
"""Build dataset"""
|
||||
raise NotImplementedError('build_dataset function not implemented in validator')
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocesses an input batch."""
|
||||
return batch
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Describes and summarizes the purpose of 'postprocess()' but no details mentioned."""
|
||||
return preds
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initialize performance metrics for the YOLO model."""
|
||||
pass
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
"""Updates metrics based on predictions and batch."""
|
||||
pass
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
"""Finalizes and returns all metrics."""
|
||||
pass
|
||||
|
||||
def get_stats(self):
|
||||
"""Returns statistics about the model's performance."""
|
||||
return {}
|
||||
|
||||
def check_stats(self, stats):
|
||||
"""Checks statistics."""
|
||||
pass
|
||||
|
||||
def print_results(self):
|
||||
"""Prints the results of the model's predictions."""
|
||||
pass
|
||||
|
||||
def get_desc(self):
|
||||
"""Get description of the YOLO model."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def metric_keys(self):
|
||||
"""Returns the metric keys used in YOLO training/validation."""
|
||||
return []
|
||||
|
||||
def on_plot(self, name, data=None):
|
||||
"""Registers plots (e.g. to be consumed in callbacks)"""
|
||||
self.plots[name] = {'data': data, 'timestamp': time.time()}
|
||||
|
||||
# TODO: may need to put these following functions into callback
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plots validation samples during training."""
|
||||
pass
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots YOLO model predictions on batch images."""
|
||||
pass
|
||||
|
||||
def pred_to_json(self, preds, batch):
|
||||
"""Convert predictions to JSON format."""
|
||||
pass
|
||||
|
||||
def eval_json(self, stats):
|
||||
"""Evaluate and return JSON format of prediction statistics."""
|
||||
pass
|
@ -1,8 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .model import FastSAM
|
||||
from .predict import FastSAMPredictor
|
||||
from .prompt import FastSAMPrompt
|
||||
from .val import FastSAMValidator
|
||||
|
||||
__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator'
|
@ -1,111 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
FastSAM model interface.
|
||||
|
||||
Usage - Predict:
|
||||
from ultralytics import FastSAM
|
||||
|
||||
model = FastSAM('last.pt')
|
||||
results = model.predict('ultralytics/assets/bus.jpg')
|
||||
"""
|
||||
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
|
||||
from ultralytics.yolo.utils.checks import check_imgsz
|
||||
|
||||
from ...yolo.utils.torch_utils import model_info, smart_inference_mode
|
||||
from .predict import FastSAMPredictor
|
||||
|
||||
|
||||
class FastSAM(YOLO):
|
||||
|
||||
def __init__(self, model='FastSAM-x.pt'):
|
||||
"""Call the __init__ method of the parent class (YOLO) with the updated default model"""
|
||||
if model == 'FastSAM.pt':
|
||||
model = 'FastSAM-x.pt'
|
||||
super().__init__(model=model)
|
||||
# any additional initialization code for FastSAM
|
||||
|
||||
@smart_inference_mode()
|
||||
def predict(self, source=None, stream=False, **kwargs):
|
||||
"""
|
||||
Perform prediction using the YOLO model.
|
||||
|
||||
Args:
|
||||
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
||||
Accepts all source types accepted by the YOLO model.
|
||||
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
||||
**kwargs : Additional keyword arguments passed to the predictor.
|
||||
Check the 'configuration' section in the documentation for all available options.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
||||
"""
|
||||
if source is None:
|
||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
||||
overrides = self.overrides.copy()
|
||||
overrides['conf'] = 0.25
|
||||
overrides.update(kwargs) # prefer kwargs
|
||||
overrides['mode'] = kwargs.get('mode', 'predict')
|
||||
assert overrides['mode'] in ['track', 'predict']
|
||||
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
|
||||
self.predictor = FastSAMPredictor(overrides=overrides)
|
||||
self.predictor.setup_model(model=self.model, verbose=False)
|
||||
|
||||
return self.predictor(source, stream=stream)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""Function trains models but raises an error as FastSAM models do not support training."""
|
||||
raise NotImplementedError("FastSAM models don't support training")
|
||||
|
||||
def val(self, **kwargs):
|
||||
"""Run validation given dataset."""
|
||||
overrides = dict(task='segment', mode='val')
|
||||
overrides.update(kwargs) # prefer kwargs
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
||||
validator = FastSAM(args=args)
|
||||
validator(model=self.model)
|
||||
self.metrics = validator.metrics
|
||||
return validator.metrics
|
||||
|
||||
@smart_inference_mode()
|
||||
def export(self, **kwargs):
|
||||
"""
|
||||
Export model.
|
||||
|
||||
Args:
|
||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
||||
"""
|
||||
overrides = dict(task='detect')
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'export'
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.task = self.task
|
||||
if args.imgsz == DEFAULT_CFG.imgsz:
|
||||
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
if args.batch == DEFAULT_CFG.batch:
|
||||
args.batch = 1 # default to 1 if not modified
|
||||
return Exporter(overrides=args)(model=self.model)
|
||||
|
||||
def info(self, detailed=False, verbose=True):
|
||||
"""
|
||||
Logs model info.
|
||||
|
||||
Args:
|
||||
detailed (bool): Show detailed information about model.
|
||||
verbose (bool): Controls verbosity.
|
||||
"""
|
||||
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
"""Calls the 'predict' function with given arguments to perform object detection."""
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
"""Raises error if object has no requested attribute."""
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
@ -1,53 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.fastsam.utils import bbox_iou
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ops
|
||||
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
||||
|
||||
|
||||
class FastSAMPredictor(DetectionPredictor):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = 'segment'
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""TODO: filter by classes."""
|
||||
p = ops.non_max_suppression(preds[0],
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
nc=len(self.model.names),
|
||||
classes=self.args.classes)
|
||||
full_box = torch.zeros_like(p[0][0])
|
||||
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
|
||||
full_box = full_box.view(1, -1)
|
||||
critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
|
||||
if critical_iou_index.numel() != 0:
|
||||
full_box[0][4] = p[0][critical_iou_index][:, 4]
|
||||
full_box[0][6:] = p[0][critical_iou_index][:, 6:]
|
||||
p[0][critical_iou_index] = full_box
|
||||
results = []
|
||||
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
|
||||
for i, pred in enumerate(p):
|
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||
path = self.batch[0]
|
||||
img_path = path[i] if isinstance(path, list) else path
|
||||
if not len(pred): # save empty boxes
|
||||
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
|
||||
continue
|
||||
if self.args.retina_masks:
|
||||
if not isinstance(orig_imgs, torch.Tensor):
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
|
||||
else:
|
||||
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
|
||||
if not isinstance(orig_imgs, torch.Tensor):
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
results.append(
|
||||
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
|
||||
return results
|
@ -1,406 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class FastSAMPrompt:
|
||||
|
||||
def __init__(self, img_path, results, device='cuda') -> None:
|
||||
# self.img_path = img_path
|
||||
self.device = device
|
||||
self.results = results
|
||||
self.img_path = img_path
|
||||
self.ori_img = cv2.imread(img_path)
|
||||
|
||||
# Import and assign clip
|
||||
try:
|
||||
import clip # for linear_assignment
|
||||
except ImportError:
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
|
||||
import clip
|
||||
self.clip = clip
|
||||
|
||||
@staticmethod
|
||||
def _segment_image(image, bbox):
|
||||
image_array = np.array(image)
|
||||
segmented_image_array = np.zeros_like(image_array)
|
||||
x1, y1, x2, y2 = bbox
|
||||
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
|
||||
segmented_image = Image.fromarray(segmented_image_array)
|
||||
black_image = Image.new('RGB', image.size, (255, 255, 255))
|
||||
# transparency_mask = np.zeros_like((), dtype=np.uint8)
|
||||
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
|
||||
transparency_mask[y1:y2, x1:x2] = 255
|
||||
transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
|
||||
black_image.paste(segmented_image, mask=transparency_mask_image)
|
||||
return black_image
|
||||
|
||||
@staticmethod
|
||||
def _format_results(result, filter=0):
|
||||
annotations = []
|
||||
n = len(result.masks.data)
|
||||
for i in range(n):
|
||||
mask = result.masks.data[i] == 1.0
|
||||
|
||||
if torch.sum(mask) < filter:
|
||||
continue
|
||||
annotation = {
|
||||
'id': i,
|
||||
'segmentation': mask.cpu().numpy(),
|
||||
'bbox': result.boxes.data[i],
|
||||
'score': result.boxes.conf[i]}
|
||||
annotation['area'] = annotation['segmentation'].sum()
|
||||
annotations.append(annotation)
|
||||
return annotations
|
||||
|
||||
@staticmethod
|
||||
def filter_masks(annotations): # filter the overlap mask
|
||||
annotations.sort(key=lambda x: x['area'], reverse=True)
|
||||
to_remove = set()
|
||||
for i in range(len(annotations)):
|
||||
a = annotations[i]
|
||||
for j in range(i + 1, len(annotations)):
|
||||
b = annotations[j]
|
||||
if i != j and j not in to_remove and b['area'] < a['area'] and \
|
||||
(a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
|
||||
to_remove.add(j)
|
||||
|
||||
return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
|
||||
|
||||
@staticmethod
|
||||
def _get_bbox_from_mask(mask):
|
||||
mask = mask.astype(np.uint8)
|
||||
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
x1, y1, w, h = cv2.boundingRect(contours[0])
|
||||
x2, y2 = x1 + w, y1 + h
|
||||
if len(contours) > 1:
|
||||
for b in contours:
|
||||
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
|
||||
# 将多个bbox合并成一个
|
||||
x1 = min(x1, x_t)
|
||||
y1 = min(y1, y_t)
|
||||
x2 = max(x2, x_t + w_t)
|
||||
y2 = max(y2, y_t + h_t)
|
||||
h = y2 - y1
|
||||
w = x2 - x1
|
||||
return [x1, y1, x2, y2]
|
||||
|
||||
def plot(self,
|
||||
annotations,
|
||||
output,
|
||||
bbox=None,
|
||||
points=None,
|
||||
point_label=None,
|
||||
mask_random_color=True,
|
||||
better_quality=True,
|
||||
retina=False,
|
||||
withContours=True):
|
||||
if isinstance(annotations[0], dict):
|
||||
annotations = [annotation['segmentation'] for annotation in annotations]
|
||||
result_name = os.path.basename(self.img_path)
|
||||
image = self.ori_img
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
original_h = image.shape[0]
|
||||
original_w = image.shape[1]
|
||||
# for macOS only
|
||||
# plt.switch_backend('TkAgg')
|
||||
plt.figure(figsize=(original_w / 100, original_h / 100))
|
||||
# Add subplot with no margin.
|
||||
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
||||
plt.margins(0, 0)
|
||||
plt.gca().xaxis.set_major_locator(plt.NullLocator())
|
||||
plt.gca().yaxis.set_major_locator(plt.NullLocator())
|
||||
|
||||
plt.imshow(image)
|
||||
if better_quality:
|
||||
if isinstance(annotations[0], torch.Tensor):
|
||||
annotations = np.array(annotations.cpu())
|
||||
for i, mask in enumerate(annotations):
|
||||
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
|
||||
annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
|
||||
if self.device == 'cpu':
|
||||
annotations = np.array(annotations)
|
||||
self.fast_show_mask(
|
||||
annotations,
|
||||
plt.gca(),
|
||||
random_color=mask_random_color,
|
||||
bbox=bbox,
|
||||
points=points,
|
||||
pointlabel=point_label,
|
||||
retinamask=retina,
|
||||
target_height=original_h,
|
||||
target_width=original_w,
|
||||
)
|
||||
else:
|
||||
if isinstance(annotations[0], np.ndarray):
|
||||
annotations = torch.from_numpy(annotations)
|
||||
self.fast_show_mask_gpu(
|
||||
annotations,
|
||||
plt.gca(),
|
||||
random_color=mask_random_color,
|
||||
bbox=bbox,
|
||||
points=points,
|
||||
pointlabel=point_label,
|
||||
retinamask=retina,
|
||||
target_height=original_h,
|
||||
target_width=original_w,
|
||||
)
|
||||
if isinstance(annotations, torch.Tensor):
|
||||
annotations = annotations.cpu().numpy()
|
||||
if withContours:
|
||||
contour_all = []
|
||||
temp = np.zeros((original_h, original_w, 1))
|
||||
for i, mask in enumerate(annotations):
|
||||
if type(mask) == dict:
|
||||
mask = mask['segmentation']
|
||||
annotation = mask.astype(np.uint8)
|
||||
if not retina:
|
||||
annotation = cv2.resize(
|
||||
annotation,
|
||||
(original_w, original_h),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
)
|
||||
contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
contour_all.extend(iter(contours))
|
||||
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
|
||||
color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
|
||||
contour_mask = temp / 255 * color.reshape(1, 1, -1)
|
||||
plt.imshow(contour_mask)
|
||||
|
||||
save_path = output
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
plt.axis('off')
|
||||
fig = plt.gcf()
|
||||
plt.draw()
|
||||
|
||||
try:
|
||||
buf = fig.canvas.tostring_rgb()
|
||||
except AttributeError:
|
||||
fig.canvas.draw()
|
||||
buf = fig.canvas.tostring_rgb()
|
||||
cols, rows = fig.canvas.get_width_height()
|
||||
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
|
||||
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
|
||||
|
||||
# CPU post process
|
||||
def fast_show_mask(
|
||||
self,
|
||||
annotation,
|
||||
ax,
|
||||
random_color=False,
|
||||
bbox=None,
|
||||
points=None,
|
||||
pointlabel=None,
|
||||
retinamask=True,
|
||||
target_height=960,
|
||||
target_width=960,
|
||||
):
|
||||
msak_sum = annotation.shape[0]
|
||||
height = annotation.shape[1]
|
||||
weight = annotation.shape[2]
|
||||
# 将annotation 按照面积 排序
|
||||
areas = np.sum(annotation, axis=(1, 2))
|
||||
sorted_indices = np.argsort(areas)
|
||||
annotation = annotation[sorted_indices]
|
||||
|
||||
index = (annotation != 0).argmax(axis=0)
|
||||
if random_color:
|
||||
color = np.random.random((msak_sum, 1, 1, 3))
|
||||
else:
|
||||
color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
|
||||
transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
|
||||
visual = np.concatenate([color, transparency], axis=-1)
|
||||
mask_image = np.expand_dims(annotation, -1) * visual
|
||||
|
||||
show = np.zeros((height, weight, 4))
|
||||
h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
|
||||
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
|
||||
# 使用向量化索引更新show的值
|
||||
show[h_indices, w_indices, :] = mask_image[indices]
|
||||
if bbox is not None:
|
||||
x1, y1, x2, y2 = bbox
|
||||
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
|
||||
# draw point
|
||||
if points is not None:
|
||||
plt.scatter(
|
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
|
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
|
||||
s=20,
|
||||
c='y',
|
||||
)
|
||||
plt.scatter(
|
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
|
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
|
||||
s=20,
|
||||
c='m',
|
||||
)
|
||||
|
||||
if not retinamask:
|
||||
show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
|
||||
ax.imshow(show)
|
||||
|
||||
def fast_show_mask_gpu(
|
||||
self,
|
||||
annotation,
|
||||
ax,
|
||||
random_color=False,
|
||||
bbox=None,
|
||||
points=None,
|
||||
pointlabel=None,
|
||||
retinamask=True,
|
||||
target_height=960,
|
||||
target_width=960,
|
||||
):
|
||||
msak_sum = annotation.shape[0]
|
||||
height = annotation.shape[1]
|
||||
weight = annotation.shape[2]
|
||||
areas = torch.sum(annotation, dim=(1, 2))
|
||||
sorted_indices = torch.argsort(areas, descending=False)
|
||||
annotation = annotation[sorted_indices]
|
||||
# 找每个位置第一个非零值下标
|
||||
index = (annotation != 0).to(torch.long).argmax(dim=0)
|
||||
if random_color:
|
||||
color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
|
||||
else:
|
||||
color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([30 / 255, 144 / 255, 1.0]).to(
|
||||
annotation.device)
|
||||
transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
|
||||
visual = torch.cat([color, transparency], dim=-1)
|
||||
mask_image = torch.unsqueeze(annotation, -1) * visual
|
||||
# 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
|
||||
show = torch.zeros((height, weight, 4)).to(annotation.device)
|
||||
h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij')
|
||||
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
|
||||
# 使用向量化索引更新show的值
|
||||
show[h_indices, w_indices, :] = mask_image[indices]
|
||||
show_cpu = show.cpu().numpy()
|
||||
if bbox is not None:
|
||||
x1, y1, x2, y2 = bbox
|
||||
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
|
||||
# draw point
|
||||
if points is not None:
|
||||
plt.scatter(
|
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
|
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
|
||||
s=20,
|
||||
c='y',
|
||||
)
|
||||
plt.scatter(
|
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
|
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
|
||||
s=20,
|
||||
c='m',
|
||||
)
|
||||
if not retinamask:
|
||||
show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
|
||||
ax.imshow(show_cpu)
|
||||
|
||||
# clip
|
||||
@torch.no_grad()
|
||||
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
|
||||
preprocessed_images = [preprocess(image).to(device) for image in elements]
|
||||
tokenized_text = self.clip.tokenize([search_text]).to(device)
|
||||
stacked_images = torch.stack(preprocessed_images)
|
||||
image_features = model.encode_image(stacked_images)
|
||||
text_features = model.encode_text(tokenized_text)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
probs = 100.0 * image_features @ text_features.T
|
||||
return probs[:, 0].softmax(dim=0)
|
||||
|
||||
def _crop_image(self, format_results):
|
||||
|
||||
image = Image.fromarray(cv2.cvtColor(self.ori_img, cv2.COLOR_BGR2RGB))
|
||||
ori_w, ori_h = image.size
|
||||
annotations = format_results
|
||||
mask_h, mask_w = annotations[0]['segmentation'].shape
|
||||
if ori_w != mask_w or ori_h != mask_h:
|
||||
image = image.resize((mask_w, mask_h))
|
||||
cropped_boxes = []
|
||||
cropped_images = []
|
||||
not_crop = []
|
||||
filter_id = []
|
||||
# annotations, _ = filter_masks(annotations)
|
||||
# filter_id = list(_)
|
||||
for _, mask in enumerate(annotations):
|
||||
if np.sum(mask['segmentation']) <= 100:
|
||||
filter_id.append(_)
|
||||
continue
|
||||
bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox
|
||||
cropped_boxes.append(self._segment_image(image, bbox)) # 保存裁剪的图片
|
||||
# cropped_boxes.append(segment_image(image,mask["segmentation"]))
|
||||
cropped_images.append(bbox) # 保存裁剪的图片的bbox
|
||||
|
||||
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
|
||||
|
||||
def box_prompt(self, bbox):
|
||||
|
||||
assert (bbox[2] != 0 and bbox[3] != 0)
|
||||
masks = self.results[0].masks.data
|
||||
target_height = self.ori_img.shape[0]
|
||||
target_width = self.ori_img.shape[1]
|
||||
h = masks.shape[1]
|
||||
w = masks.shape[2]
|
||||
if h != target_height or w != target_width:
|
||||
bbox = [
|
||||
int(bbox[0] * w / target_width),
|
||||
int(bbox[1] * h / target_height),
|
||||
int(bbox[2] * w / target_width),
|
||||
int(bbox[3] * h / target_height), ]
|
||||
bbox[0] = max(round(bbox[0]), 0)
|
||||
bbox[1] = max(round(bbox[1]), 0)
|
||||
bbox[2] = min(round(bbox[2]), w)
|
||||
bbox[3] = min(round(bbox[3]), h)
|
||||
|
||||
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
|
||||
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
|
||||
|
||||
masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
|
||||
orig_masks_area = torch.sum(masks, dim=(1, 2))
|
||||
|
||||
union = bbox_area + orig_masks_area - masks_area
|
||||
IoUs = masks_area / union
|
||||
max_iou_index = torch.argmax(IoUs)
|
||||
|
||||
return np.array([masks[max_iou_index].cpu().numpy()])
|
||||
|
||||
def point_prompt(self, points, pointlabel): # numpy 处理
|
||||
|
||||
masks = self._format_results(self.results[0], 0)
|
||||
target_height = self.ori_img.shape[0]
|
||||
target_width = self.ori_img.shape[1]
|
||||
h = masks[0]['segmentation'].shape[0]
|
||||
w = masks[0]['segmentation'].shape[1]
|
||||
if h != target_height or w != target_width:
|
||||
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
|
||||
onemask = np.zeros((h, w))
|
||||
for i, annotation in enumerate(masks):
|
||||
mask = annotation['segmentation'] if type(annotation) == dict else annotation
|
||||
for i, point in enumerate(points):
|
||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
|
||||
onemask += mask
|
||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
|
||||
onemask -= mask
|
||||
onemask = onemask >= 1
|
||||
return np.array([onemask])
|
||||
|
||||
def text_prompt(self, text):
|
||||
format_results = self._format_results(self.results[0], 0)
|
||||
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
|
||||
clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
|
||||
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
|
||||
max_idx = scores.argsort()
|
||||
max_idx = max_idx[-1]
|
||||
max_idx += sum(np.array(filter_id) <= int(max_idx))
|
||||
return np.array([annotations[max_idx]['segmentation']])
|
||||
|
||||
def everything_prompt(self):
|
||||
return self.results[0].masks.data
|
@ -1,64 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
|
||||
"""
|
||||
Adjust bounding boxes to stick to image border if they are within a certain threshold.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor): (n, 4)
|
||||
image_shape (tuple): (height, width)
|
||||
threshold (int): pixel threshold
|
||||
|
||||
Returns:
|
||||
adjusted_boxes (torch.Tensor): adjusted bounding boxes
|
||||
"""
|
||||
|
||||
# Image dimensions
|
||||
h, w = image_shape
|
||||
|
||||
# Adjust boxes
|
||||
boxes[boxes[:, 0] < threshold, 0] = 0 # x1
|
||||
boxes[boxes[:, 1] < threshold, 1] = 0 # y1
|
||||
boxes[boxes[:, 2] > w - threshold, 2] = w # x2
|
||||
boxes[boxes[:, 3] > h - threshold, 3] = h # y2
|
||||
return boxes
|
||||
|
||||
|
||||
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
|
||||
"""
|
||||
Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
|
||||
|
||||
Args:
|
||||
box1 (torch.Tensor): (4, )
|
||||
boxes (torch.Tensor): (n, 4)
|
||||
|
||||
Returns:
|
||||
high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres
|
||||
"""
|
||||
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
|
||||
# obtain coordinates for intersections
|
||||
x1 = torch.max(box1[0], boxes[:, 0])
|
||||
y1 = torch.max(box1[1], boxes[:, 1])
|
||||
x2 = torch.min(box1[2], boxes[:, 2])
|
||||
y2 = torch.min(box1[3], boxes[:, 3])
|
||||
|
||||
# compute the area of intersection
|
||||
intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
|
||||
|
||||
# compute the area of both individual boxes
|
||||
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
||||
box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
||||
|
||||
# compute the area of union
|
||||
union = box1_area + box2_area - intersection
|
||||
|
||||
# compute the IoU
|
||||
iou = intersection / union # Should be shape (n, )
|
||||
if raw_output:
|
||||
return 0 if iou.numel() == 0 else iou
|
||||
|
||||
# return indices of boxes with IoU > thres
|
||||
return torch.nonzero(iou > iou_thres).flatten()
|
@ -1,244 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, NUM_THREADS, ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.yolo.utils.metrics import SegmentMetrics, box_iou, mask_iou
|
||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
|
||||
from ultralytics.yolo.v8.detect import DetectionValidator
|
||||
|
||||
|
||||
class FastSAMValidator(DetectionValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
|
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.args.task = 'segment'
|
||||
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocesses batch by converting masks to float and sending to device."""
|
||||
batch = super().preprocess(batch)
|
||||
batch['masks'] = batch['masks'].to(self.device).float()
|
||||
return batch
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initialize metrics and select mask processing function based on save_json flag."""
|
||||
super().init_metrics(model)
|
||||
self.plot_masks = []
|
||||
if self.args.save_json:
|
||||
check_requirements('pycocotools>=2.0.6')
|
||||
self.process = ops.process_mask_upsample # more accurate
|
||||
else:
|
||||
self.process = ops.process_mask # faster
|
||||
|
||||
def get_desc(self):
|
||||
"""Return a formatted description of evaluation metrics."""
|
||||
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
|
||||
'R', 'mAP50', 'mAP50-95)')
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Postprocesses YOLO predictions and returns output detections with proto."""
|
||||
p = ops.non_max_suppression(preds[0],
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
labels=self.lb,
|
||||
multi_label=True,
|
||||
agnostic=self.args.single_cls,
|
||||
max_det=self.args.max_det,
|
||||
nc=self.nc)
|
||||
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
|
||||
return p, proto
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
"""Metrics."""
|
||||
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
|
||||
idx = batch['batch_idx'] == si
|
||||
cls = batch['cls'][idx]
|
||||
bbox = batch['bboxes'][idx]
|
||||
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
|
||||
shape = batch['ori_shape'][si]
|
||||
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
self.seen += 1
|
||||
|
||||
if npr == 0:
|
||||
if nl:
|
||||
self.stats.append((correct_bboxes, correct_masks, *torch.zeros(
|
||||
(2, 0), device=self.device), cls.squeeze(-1)))
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
|
||||
continue
|
||||
|
||||
# Masks
|
||||
midx = [si] if self.args.overlap_mask else idx
|
||||
gt_masks = batch['masks'][midx]
|
||||
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])
|
||||
|
||||
# Predictions
|
||||
if self.args.single_cls:
|
||||
pred[:, 5] = 0
|
||||
predn = pred.clone()
|
||||
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
|
||||
ratio_pad=batch['ratio_pad'][si]) # native-space pred
|
||||
|
||||
# Evaluate
|
||||
if nl:
|
||||
height, width = batch['img'].shape[2:]
|
||||
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
|
||||
(width, height, width, height), device=self.device) # target boxes
|
||||
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
|
||||
ratio_pad=batch['ratio_pad'][si]) # native-space labels
|
||||
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
||||
correct_bboxes = self._process_batch(predn, labelsn)
|
||||
# TODO: maybe remove these `self.` arguments as they already are member variable
|
||||
correct_masks = self._process_batch(predn,
|
||||
labelsn,
|
||||
pred_masks,
|
||||
gt_masks,
|
||||
overlap=self.args.overlap_mask,
|
||||
masks=True)
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(predn, labelsn)
|
||||
|
||||
# Append correct_masks, correct_boxes, pconf, pcls, tcls
|
||||
self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
|
||||
|
||||
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
|
||||
if self.args.plots and self.batch_i < 3:
|
||||
self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
|
||||
|
||||
# Save
|
||||
if self.args.save_json:
|
||||
pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
|
||||
shape,
|
||||
ratio_pad=batch['ratio_pad'][si])
|
||||
self.pred_to_json(predn, batch['im_file'][si], pred_masks)
|
||||
# if self.args.save_txt:
|
||||
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
"""Sets speed and confusion matrix for evaluation metrics."""
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
|
||||
"""
|
||||
Return correct prediction matrix
|
||||
Arguments:
|
||||
detections (array[N, 6]), x1, y1, x2, y2, conf, class
|
||||
labels (array[M, 5]), class, x1, y1, x2, y2
|
||||
Returns:
|
||||
correct (array[N, 10]), for 10 IoU levels
|
||||
"""
|
||||
if masks:
|
||||
if overlap:
|
||||
nl = len(labels)
|
||||
index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
|
||||
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
|
||||
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
|
||||
if gt_masks.shape[1:] != pred_masks.shape[1:]:
|
||||
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
|
||||
gt_masks = gt_masks.gt_(0.5)
|
||||
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
|
||||
else: # boxes
|
||||
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||
|
||||
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
|
||||
correct_class = labels[:, 0:1] == detections[:, 5]
|
||||
for i in range(len(self.iouv)):
|
||||
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
|
||||
if x[0].shape[0]:
|
||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
|
||||
1).cpu().numpy() # [label, detect, iou]
|
||||
if x[0].shape[0] > 1:
|
||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
||||
# matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
||||
correct[matches[:, 1].astype(int), i] = True
|
||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plots validation samples with bounding box labels."""
|
||||
plot_images(batch['img'],
|
||||
batch['batch_idx'],
|
||||
batch['cls'].squeeze(-1),
|
||||
batch['bboxes'],
|
||||
batch['masks'],
|
||||
paths=batch['im_file'],
|
||||
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots batch predictions with masks and bounding boxes."""
|
||||
plot_images(
|
||||
batch['img'],
|
||||
*output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
|
||||
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
|
||||
paths=batch['im_file'],
|
||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||
names=self.names,
|
||||
on_plot=self.on_plot) # pred
|
||||
self.plot_masks.clear()
|
||||
|
||||
def pred_to_json(self, predn, filename, pred_masks):
|
||||
"""Save one JSON result."""
|
||||
# Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
|
||||
from pycocotools.mask import encode # noqa
|
||||
|
||||
def single_encode(x):
|
||||
"""Encode predicted masks as RLE and append results to jdict."""
|
||||
rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
|
||||
rle['counts'] = rle['counts'].decode('utf-8')
|
||||
return rle
|
||||
|
||||
stem = Path(filename).stem
|
||||
image_id = int(stem) if stem.isnumeric() else stem
|
||||
box = ops.xyxy2xywh(predn[:, :4]) # xywh
|
||||
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
|
||||
pred_masks = np.transpose(pred_masks, (2, 0, 1))
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
rles = pool.map(single_encode, pred_masks)
|
||||
for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
|
||||
self.jdict.append({
|
||||
'image_id': image_id,
|
||||
'category_id': self.class_map[int(p[5])],
|
||||
'bbox': [round(x, 3) for x in b],
|
||||
'score': round(p[4], 5),
|
||||
'segmentation': rles[i]})
|
||||
|
||||
def eval_json(self, stats):
|
||||
"""Return COCO-style object detection evaluation metrics."""
|
||||
if self.args.save_json and self.is_coco and len(self.jdict):
|
||||
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
|
||||
pred_json = self.save_dir / 'predictions.json' # predictions
|
||||
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
|
||||
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
||||
check_requirements('pycocotools>=2.0.6')
|
||||
from pycocotools.coco import COCO # noqa
|
||||
from pycocotools.cocoeval import COCOeval # noqa
|
||||
|
||||
for x in anno_json, pred_json:
|
||||
assert x.is_file(), f'{x} file not found'
|
||||
anno = COCO(str(anno_json)) # init annotations api
|
||||
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
|
||||
if self.is_coco:
|
||||
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
|
||||
eval.evaluate()
|
||||
eval.accumulate()
|
||||
eval.summarize()
|
||||
idx = i * 4 + 2
|
||||
stats[self.metrics.keys[idx + 1]], stats[
|
||||
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'pycocotools unable to run: {e}')
|
||||
return stats
|
@ -1,7 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .model import NAS
|
||||
from .predict import NASPredictor
|
||||
from .val import NASValidator
|
||||
|
||||
__all__ = 'NASPredictor', 'NASValidator', 'NAS'
|
@ -1,133 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
YOLO-NAS model interface.
|
||||
|
||||
Usage - Predict:
|
||||
from ultralytics import NAS
|
||||
|
||||
model = NAS('yolo_nas_s')
|
||||
results = model.predict('ultralytics/assets/bus.jpg')
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
|
||||
from ultralytics.yolo.utils.checks import check_imgsz
|
||||
|
||||
from ...yolo.utils.torch_utils import model_info, smart_inference_mode
|
||||
from .predict import NASPredictor
|
||||
from .val import NASValidator
|
||||
|
||||
|
||||
class NAS:
|
||||
|
||||
def __init__(self, model='yolo_nas_s.pt') -> None:
|
||||
# Load or create new NAS model
|
||||
import super_gradients
|
||||
|
||||
self.predictor = None
|
||||
suffix = Path(model).suffix
|
||||
if suffix == '.pt':
|
||||
self._load(model)
|
||||
elif suffix == '':
|
||||
self.model = super_gradients.training.models.get(model, pretrained_weights='coco')
|
||||
self.task = 'detect'
|
||||
self.model.args = DEFAULT_CFG_DICT # attach args to model
|
||||
|
||||
# Standardize model
|
||||
self.model.fuse = lambda verbose=True: self.model
|
||||
self.model.stride = torch.tensor([32])
|
||||
self.model.names = dict(enumerate(self.model._class_names))
|
||||
self.model.is_fused = lambda: False # for info()
|
||||
self.model.yaml = {} # for info()
|
||||
self.model.pt_path = model # for export()
|
||||
self.model.task = 'detect' # for export()
|
||||
self.info()
|
||||
|
||||
@smart_inference_mode()
|
||||
def _load(self, weights: str):
|
||||
self.model = torch.load(weights)
|
||||
|
||||
@smart_inference_mode()
|
||||
def predict(self, source=None, stream=False, **kwargs):
|
||||
"""
|
||||
Perform prediction using the YOLO model.
|
||||
|
||||
Args:
|
||||
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
||||
Accepts all source types accepted by the YOLO model.
|
||||
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
||||
**kwargs : Additional keyword arguments passed to the predictor.
|
||||
Check the 'configuration' section in the documentation for all available options.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
||||
"""
|
||||
if source is None:
|
||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
||||
overrides = dict(conf=0.25, task='detect', mode='predict')
|
||||
overrides.update(kwargs) # prefer kwargs
|
||||
if not self.predictor:
|
||||
self.predictor = NASPredictor(overrides=overrides)
|
||||
self.predictor.setup_model(model=self.model)
|
||||
else: # only update args if predictor is already setup
|
||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||
return self.predictor(source, stream=stream)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""Function trains models but raises an error as NAS models do not support training."""
|
||||
raise NotImplementedError("NAS models don't support training")
|
||||
|
||||
def val(self, **kwargs):
|
||||
"""Run validation given dataset."""
|
||||
overrides = dict(task='detect', mode='val')
|
||||
overrides.update(kwargs) # prefer kwargs
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
||||
validator = NASValidator(args=args)
|
||||
validator(model=self.model)
|
||||
self.metrics = validator.metrics
|
||||
return validator.metrics
|
||||
|
||||
@smart_inference_mode()
|
||||
def export(self, **kwargs):
|
||||
"""
|
||||
Export model.
|
||||
|
||||
Args:
|
||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
||||
"""
|
||||
overrides = dict(task='detect')
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'export'
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.task = self.task
|
||||
if args.imgsz == DEFAULT_CFG.imgsz:
|
||||
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
if args.batch == DEFAULT_CFG.batch:
|
||||
args.batch = 1 # default to 1 if not modified
|
||||
return Exporter(overrides=args)(model=self.model)
|
||||
|
||||
def info(self, detailed=False, verbose=True):
|
||||
"""
|
||||
Logs model info.
|
||||
|
||||
Args:
|
||||
detailed (bool): Show detailed information about model.
|
||||
verbose (bool): Controls verbosity.
|
||||
"""
|
||||
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
"""Calls the 'predict' function with given arguments to perform object detection."""
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
"""Raises error if object has no requested attribute."""
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
@ -1,35 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import ops
|
||||
from ultralytics.yolo.utils.ops import xyxy2xywh
|
||||
|
||||
|
||||
class NASPredictor(BasePredictor):
|
||||
|
||||
def postprocess(self, preds_in, img, orig_imgs):
|
||||
"""Postprocesses predictions and returns a list of Results objects."""
|
||||
|
||||
# Cat boxes and class scores
|
||||
boxes = xyxy2xywh(preds_in[0][0])
|
||||
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
|
||||
|
||||
preds = ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
classes=self.args.classes)
|
||||
|
||||
results = []
|
||||
for i, pred in enumerate(preds):
|
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||
if not isinstance(orig_imgs, torch.Tensor):
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
path = self.batch[0]
|
||||
img_path = path[i] if isinstance(path, list) else path
|
||||
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
|
||||
return results
|
@ -1,25 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.utils import ops
|
||||
from ultralytics.yolo.utils.ops import xyxy2xywh
|
||||
from ultralytics.yolo.v8.detect import DetectionValidator
|
||||
|
||||
__all__ = ['NASValidator']
|
||||
|
||||
|
||||
class NASValidator(DetectionValidator):
|
||||
|
||||
def postprocess(self, preds_in):
|
||||
"""Apply Non-maximum suppression to prediction outputs."""
|
||||
boxes = xyxy2xywh(preds_in[0][0])
|
||||
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
|
||||
return ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
labels=self.lb,
|
||||
multi_label=False,
|
||||
agnostic=self.args.single_cls,
|
||||
max_det=self.args.max_det,
|
||||
max_time_img=0.5)
|
@ -1,809 +1,15 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging.config
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
import importlib
|
||||
import sys
|
||||
import threading
|
||||
import urllib
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Union
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from ultralytics.utils import LOGGER
|
||||
|
||||
from ultralytics import __version__
|
||||
# Set modules in sys.modules under their old name
|
||||
sys.modules['ultralytics.yolo.utils'] = importlib.import_module('ultralytics.utils')
|
||||
|
||||
# PyTorch Multi-GPU DDP Constants
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
|
||||
|
||||
# Other Constants
|
||||
FILE = Path(__file__).resolve()
|
||||
ROOT = FILE.parents[2] # YOLO
|
||||
DEFAULT_CFG_PATH = ROOT / 'yolo/cfg/default.yaml'
|
||||
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
||||
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
||||
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode
|
||||
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format
|
||||
LOGGING_NAME = 'ultralytics'
|
||||
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ['Darwin', 'Linux', 'Windows']) # environment booleans
|
||||
ARM64 = platform.machine() in ('arm64', 'aarch64') # ARM64 booleans
|
||||
HELP_MSG = \
|
||||
"""
|
||||
Usage examples for running YOLOv8:
|
||||
|
||||
1. Install the ultralytics package:
|
||||
|
||||
pip install ultralytics
|
||||
|
||||
2. Use the Python SDK:
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Load a model
|
||||
model = YOLO('yolov8n.yaml') # build a new model from scratch
|
||||
model = YOLO("yolov8n.pt") # load a pretrained model (recommended for training)
|
||||
|
||||
# Use the model
|
||||
results = model.train(data="coco128.yaml", epochs=3) # train the model
|
||||
results = model.val() # evaluate model performance on the validation set
|
||||
results = model('https://ultralytics.com/images/bus.jpg') # predict on an image
|
||||
success = model.export(format='onnx') # export the model to ONNX format
|
||||
|
||||
3. Use the command line interface (CLI):
|
||||
|
||||
YOLOv8 'yolo' CLI commands use the following syntax:
|
||||
|
||||
yolo TASK MODE ARGS
|
||||
|
||||
Where TASK (optional) is one of [detect, segment, classify]
|
||||
MODE (required) is one of [train, val, predict, export]
|
||||
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||||
See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'
|
||||
|
||||
- Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
||||
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||
|
||||
- Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||
yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
|
||||
|
||||
- Val a pretrained detection model at batch-size 1 and image size 640:
|
||||
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||
|
||||
- Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||||
|
||||
- Run special commands:
|
||||
yolo help
|
||||
yolo checks
|
||||
yolo version
|
||||
yolo settings
|
||||
yolo copy-cfg
|
||||
yolo cfg
|
||||
|
||||
Docs: https://docs.ultralytics.com
|
||||
Community: https://community.ultralytics.com
|
||||
GitHub: https://github.com/ultralytics/ultralytics
|
||||
"""
|
||||
|
||||
# Settings
|
||||
torch.set_printoptions(linewidth=320, precision=4, profile='default')
|
||||
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
||||
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
||||
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # suppress verbose TF compiler warnings in Colab
|
||||
|
||||
|
||||
class SimpleClass:
|
||||
"""
|
||||
Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
|
||||
access methods for easier debugging and usage.
|
||||
"""
|
||||
|
||||
def __str__(self):
|
||||
"""Return a human-readable string representation of the object."""
|
||||
attr = []
|
||||
for a in dir(self):
|
||||
v = getattr(self, a)
|
||||
if not callable(v) and not a.startswith('_'):
|
||||
if isinstance(v, SimpleClass):
|
||||
# Display only the module and class name for subclasses
|
||||
s = f'{a}: {v.__module__}.{v.__class__.__name__} object'
|
||||
else:
|
||||
s = f'{a}: {repr(v)}'
|
||||
attr.append(s)
|
||||
return f'{self.__module__}.{self.__class__.__name__} object with attributes:\n\n' + '\n'.join(attr)
|
||||
|
||||
def __repr__(self):
|
||||
"""Return a machine-readable string representation of the object."""
|
||||
return self.__str__()
|
||||
|
||||
def __getattr__(self, attr):
|
||||
"""Custom attribute access error message with helpful information."""
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
|
||||
class IterableSimpleNamespace(SimpleNamespace):
|
||||
"""
|
||||
Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
|
||||
enables usage with dict() and for loops.
|
||||
"""
|
||||
|
||||
def __iter__(self):
|
||||
"""Return an iterator of key-value pairs from the namespace's attributes."""
|
||||
return iter(vars(self).items())
|
||||
|
||||
def __str__(self):
|
||||
"""Return a human-readable string representation of the object."""
|
||||
return '\n'.join(f'{k}={v}' for k, v in vars(self).items())
|
||||
|
||||
def __getattr__(self, attr):
|
||||
"""Custom attribute access error message with helpful information."""
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"""
|
||||
'{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics
|
||||
'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace
|
||||
{DEFAULT_CFG_PATH} with the latest version from
|
||||
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/cfg/default.yaml
|
||||
""")
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""Return the value of the specified key if it exists; otherwise, return the default value."""
|
||||
return getattr(self, key, default)
|
||||
|
||||
|
||||
def plt_settings(rcparams=None, backend='Agg'):
|
||||
"""
|
||||
Decorator to temporarily set rc parameters and the backend for a plotting function.
|
||||
|
||||
Usage:
|
||||
decorator: @plt_settings({"font.size": 12})
|
||||
context manager: with plt_settings({"font.size": 12}):
|
||||
|
||||
Args:
|
||||
rcparams (dict): Dictionary of rc parameters to set.
|
||||
backend (str, optional): Name of the backend to use. Defaults to 'Agg'.
|
||||
|
||||
Returns:
|
||||
(Callable): Decorated function with temporarily set rc parameters and backend. This decorator can be
|
||||
applied to any function that needs to have specific matplotlib rc parameters and backend for its execution.
|
||||
"""
|
||||
|
||||
if rcparams is None:
|
||||
rcparams = {'font.size': 11}
|
||||
|
||||
def decorator(func):
|
||||
"""Decorator to apply temporary rc parameters and backend to a function."""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Sets rc parameters and backend, calls the original function, and restores the settings."""
|
||||
original_backend = plt.get_backend()
|
||||
plt.switch_backend(backend)
|
||||
|
||||
with plt.rc_context(rcparams):
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
plt.switch_backend(original_backend)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def set_logging(name=LOGGING_NAME, verbose=True):
|
||||
"""Sets up logging for the given name."""
|
||||
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
|
||||
level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
|
||||
logging.config.dictConfig({
|
||||
'version': 1,
|
||||
'disable_existing_loggers': False,
|
||||
'formatters': {
|
||||
name: {
|
||||
'format': '%(message)s'}},
|
||||
'handlers': {
|
||||
name: {
|
||||
'class': 'logging.StreamHandler',
|
||||
'formatter': name,
|
||||
'level': level}},
|
||||
'loggers': {
|
||||
name: {
|
||||
'level': level,
|
||||
'handlers': [name],
|
||||
'propagate': False}}})
|
||||
|
||||
|
||||
def emojis(string=''):
|
||||
"""Return platform-dependent emoji-safe version of string."""
|
||||
return string.encode().decode('ascii', 'ignore') if WINDOWS else string
|
||||
|
||||
|
||||
class EmojiFilter(logging.Filter):
|
||||
"""
|
||||
A custom logging filter class for removing emojis in log messages.
|
||||
|
||||
This filter is particularly useful for ensuring compatibility with Windows terminals
|
||||
that may not support the display of emojis in log messages.
|
||||
"""
|
||||
|
||||
def filter(self, record):
|
||||
"""Filter logs by emoji unicode characters on windows."""
|
||||
record.msg = emojis(record.msg)
|
||||
return super().filter(record)
|
||||
|
||||
|
||||
# Set logger
|
||||
set_logging(LOGGING_NAME, verbose=VERBOSE) # run before defining LOGGER
|
||||
LOGGER = logging.getLogger(LOGGING_NAME) # define globally (used in train.py, val.py, detect.py, etc.)
|
||||
if WINDOWS: # emoji-safe logging
|
||||
LOGGER.addFilter(EmojiFilter())
|
||||
|
||||
|
||||
class ThreadingLocked:
|
||||
"""
|
||||
A decorator class for ensuring thread-safe execution of a function or method.
|
||||
This class can be used as a decorator to make sure that if the decorated function
|
||||
is called from multiple threads, only one thread at a time will be able to execute the function.
|
||||
|
||||
Attributes:
|
||||
lock (threading.Lock): A lock object used to manage access to the decorated function.
|
||||
|
||||
Usage:
|
||||
@ThreadingLocked()
|
||||
def my_function():
|
||||
# Your code here
|
||||
pass
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def __call__(self, f):
|
||||
from functools import wraps
|
||||
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
with self.lock:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def yaml_save(file='data.yaml', data=None):
|
||||
"""
|
||||
Save YAML data to a file.
|
||||
|
||||
Args:
|
||||
file (str, optional): File name. Default is 'data.yaml'.
|
||||
data (dict): Data to save in YAML format.
|
||||
|
||||
Returns:
|
||||
(None): Data is saved to the specified file.
|
||||
"""
|
||||
if data is None:
|
||||
data = {}
|
||||
file = Path(file)
|
||||
if not file.parent.exists():
|
||||
# Create parent directories if they don't exist
|
||||
file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Convert Path objects to strings
|
||||
for k, v in data.items():
|
||||
if isinstance(v, Path):
|
||||
data[k] = str(v)
|
||||
|
||||
# Dump data to file in YAML format
|
||||
with open(file, 'w') as f:
|
||||
yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)
|
||||
|
||||
|
||||
def yaml_load(file='data.yaml', append_filename=False):
|
||||
"""
|
||||
Load YAML data from a file.
|
||||
|
||||
Args:
|
||||
file (str, optional): File name. Default is 'data.yaml'.
|
||||
append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.
|
||||
|
||||
Returns:
|
||||
(dict): YAML data and file name.
|
||||
"""
|
||||
with open(file, errors='ignore', encoding='utf-8') as f:
|
||||
s = f.read() # string
|
||||
|
||||
# Remove special characters
|
||||
if not s.isprintable():
|
||||
s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)
|
||||
|
||||
# Add YAML filename to dict and return
|
||||
return {**yaml.safe_load(s), 'yaml_file': str(file)} if append_filename else yaml.safe_load(s)
|
||||
|
||||
|
||||
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
|
||||
"""
|
||||
Pretty prints a yaml file or a yaml-formatted dictionary.
|
||||
|
||||
Args:
|
||||
yaml_file: The file path of the yaml file or a yaml-formatted dictionary.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
|
||||
dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True)
|
||||
LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
|
||||
|
||||
|
||||
# Default configuration
|
||||
DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
|
||||
for k, v in DEFAULT_CFG_DICT.items():
|
||||
if isinstance(v, str) and v.lower() == 'none':
|
||||
DEFAULT_CFG_DICT[k] = None
|
||||
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
|
||||
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
|
||||
|
||||
|
||||
def is_colab():
|
||||
"""
|
||||
Check if the current script is running inside a Google Colab notebook.
|
||||
|
||||
Returns:
|
||||
(bool): True if running inside a Colab notebook, False otherwise.
|
||||
"""
|
||||
return 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_BACKEND_VERSION' in os.environ
|
||||
|
||||
|
||||
def is_kaggle():
|
||||
"""
|
||||
Check if the current script is running inside a Kaggle kernel.
|
||||
|
||||
Returns:
|
||||
(bool): True if running inside a Kaggle kernel, False otherwise.
|
||||
"""
|
||||
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
|
||||
|
||||
|
||||
def is_jupyter():
|
||||
"""
|
||||
Check if the current script is running inside a Jupyter Notebook.
|
||||
Verified on Colab, Jupyterlab, Kaggle, Paperspace.
|
||||
|
||||
Returns:
|
||||
(bool): True if running inside a Jupyter Notebook, False otherwise.
|
||||
"""
|
||||
with contextlib.suppress(Exception):
|
||||
from IPython import get_ipython
|
||||
return get_ipython() is not None
|
||||
return False
|
||||
|
||||
|
||||
def is_docker() -> bool:
|
||||
"""
|
||||
Determine if the script is running inside a Docker container.
|
||||
|
||||
Returns:
|
||||
(bool): True if the script is running inside a Docker container, False otherwise.
|
||||
"""
|
||||
file = Path('/proc/self/cgroup')
|
||||
if file.exists():
|
||||
with open(file) as f:
|
||||
return 'docker' in f.read()
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_online() -> bool:
|
||||
"""
|
||||
Check internet connectivity by attempting to connect to a known online host.
|
||||
|
||||
Returns:
|
||||
(bool): True if connection is successful, False otherwise.
|
||||
"""
|
||||
import socket
|
||||
|
||||
for host in '1.1.1.1', '8.8.8.8', '223.5.5.5': # Cloudflare, Google, AliDNS:
|
||||
try:
|
||||
test_connection = socket.create_connection(address=(host, 53), timeout=2)
|
||||
except (socket.timeout, socket.gaierror, OSError):
|
||||
continue
|
||||
else:
|
||||
# If the connection was successful, close it to avoid a ResourceWarning
|
||||
test_connection.close()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
ONLINE = is_online()
|
||||
|
||||
|
||||
def is_pip_package(filepath: str = __name__) -> bool:
|
||||
"""
|
||||
Determines if the file at the given filepath is part of a pip package.
|
||||
|
||||
Args:
|
||||
filepath (str): The filepath to check.
|
||||
|
||||
Returns:
|
||||
(bool): True if the file is part of a pip package, False otherwise.
|
||||
"""
|
||||
import importlib.util
|
||||
|
||||
# Get the spec for the module
|
||||
spec = importlib.util.find_spec(filepath)
|
||||
|
||||
# Return whether the spec is not None and the origin is not None (indicating it is a package)
|
||||
return spec is not None and spec.origin is not None
|
||||
|
||||
|
||||
def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
|
||||
"""
|
||||
Check if a directory is writeable.
|
||||
|
||||
Args:
|
||||
dir_path (str | Path): The path to the directory.
|
||||
|
||||
Returns:
|
||||
(bool): True if the directory is writeable, False otherwise.
|
||||
"""
|
||||
return os.access(str(dir_path), os.W_OK)
|
||||
|
||||
|
||||
def is_pytest_running():
|
||||
"""
|
||||
Determines whether pytest is currently running or not.
|
||||
|
||||
Returns:
|
||||
(bool): True if pytest is running, False otherwise.
|
||||
"""
|
||||
return ('PYTEST_CURRENT_TEST' in os.environ) or ('pytest' in sys.modules) or ('pytest' in Path(sys.argv[0]).stem)
|
||||
|
||||
|
||||
def is_github_actions_ci() -> bool:
|
||||
"""
|
||||
Determine if the current environment is a GitHub Actions CI Python runner.
|
||||
|
||||
Returns:
|
||||
(bool): True if the current environment is a GitHub Actions CI Python runner, False otherwise.
|
||||
"""
|
||||
return 'GITHUB_ACTIONS' in os.environ and 'RUNNER_OS' in os.environ and 'RUNNER_TOOL_CACHE' in os.environ
|
||||
|
||||
|
||||
def is_git_dir():
|
||||
"""
|
||||
Determines whether the current file is part of a git repository.
|
||||
If the current file is not part of a git repository, returns None.
|
||||
|
||||
Returns:
|
||||
(bool): True if current file is part of a git repository.
|
||||
"""
|
||||
return get_git_dir() is not None
|
||||
|
||||
|
||||
def get_git_dir():
|
||||
"""
|
||||
Determines whether the current file is part of a git repository and if so, returns the repository root directory.
|
||||
If the current file is not part of a git repository, returns None.
|
||||
|
||||
Returns:
|
||||
(Path | None): Git root directory if found or None if not found.
|
||||
"""
|
||||
for d in Path(__file__).parents:
|
||||
if (d / '.git').is_dir():
|
||||
return d
|
||||
return None # no .git dir found
|
||||
|
||||
|
||||
def get_git_origin_url():
|
||||
"""
|
||||
Retrieves the origin URL of a git repository.
|
||||
|
||||
Returns:
|
||||
(str | None): The origin URL of the git repository.
|
||||
"""
|
||||
if is_git_dir():
|
||||
with contextlib.suppress(subprocess.CalledProcessError):
|
||||
origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url'])
|
||||
return origin.decode().strip()
|
||||
return None # if not git dir or on error
|
||||
|
||||
|
||||
def get_git_branch():
|
||||
"""
|
||||
Returns the current git branch name. If not in a git repository, returns None.
|
||||
|
||||
Returns:
|
||||
(str | None): The current git branch name.
|
||||
"""
|
||||
if is_git_dir():
|
||||
with contextlib.suppress(subprocess.CalledProcessError):
|
||||
origin = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
||||
return origin.decode().strip()
|
||||
return None # if not git dir or on error
|
||||
|
||||
|
||||
def get_default_args(func):
|
||||
"""Returns a dictionary of default arguments for a function.
|
||||
|
||||
Args:
|
||||
func (callable): The function to inspect.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary where each key is a parameter name, and each value is the default value of that parameter.
|
||||
"""
|
||||
signature = inspect.signature(func)
|
||||
return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
|
||||
|
||||
|
||||
def get_user_config_dir(sub_dir='Ultralytics'):
|
||||
"""
|
||||
Get the user config directory.
|
||||
|
||||
Args:
|
||||
sub_dir (str): The name of the subdirectory to create.
|
||||
|
||||
Returns:
|
||||
(Path): The path to the user config directory.
|
||||
"""
|
||||
# Return the appropriate config directory for each operating system
|
||||
if WINDOWS:
|
||||
path = Path.home() / 'AppData' / 'Roaming' / sub_dir
|
||||
elif MACOS: # macOS
|
||||
path = Path.home() / 'Library' / 'Application Support' / sub_dir
|
||||
elif LINUX:
|
||||
path = Path.home() / '.config' / sub_dir
|
||||
else:
|
||||
raise ValueError(f'Unsupported operating system: {platform.system()}')
|
||||
|
||||
# GCP and AWS lambda fix, only /tmp is writeable
|
||||
if not is_dir_writeable(str(path.parent)):
|
||||
path = Path('/tmp') / sub_dir
|
||||
LOGGER.warning(f"WARNING ⚠️ user config directory is not writeable, defaulting to '{path}'.")
|
||||
|
||||
# Create the subdirectory if it does not exist
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
USER_CONFIG_DIR = Path(os.getenv('YOLO_CONFIG_DIR', get_user_config_dir())) # Ultralytics settings dir
|
||||
SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'
|
||||
|
||||
|
||||
def colorstr(*input):
|
||||
"""Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')."""
|
||||
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
|
||||
colors = {
|
||||
'black': '\033[30m', # basic colors
|
||||
'red': '\033[31m',
|
||||
'green': '\033[32m',
|
||||
'yellow': '\033[33m',
|
||||
'blue': '\033[34m',
|
||||
'magenta': '\033[35m',
|
||||
'cyan': '\033[36m',
|
||||
'white': '\033[37m',
|
||||
'bright_black': '\033[90m', # bright colors
|
||||
'bright_red': '\033[91m',
|
||||
'bright_green': '\033[92m',
|
||||
'bright_yellow': '\033[93m',
|
||||
'bright_blue': '\033[94m',
|
||||
'bright_magenta': '\033[95m',
|
||||
'bright_cyan': '\033[96m',
|
||||
'bright_white': '\033[97m',
|
||||
'end': '\033[0m', # misc
|
||||
'bold': '\033[1m',
|
||||
'underline': '\033[4m'}
|
||||
return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
|
||||
|
||||
|
||||
class TryExcept(contextlib.ContextDecorator):
|
||||
"""YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
|
||||
|
||||
def __init__(self, msg='', verbose=True):
|
||||
"""Initialize TryExcept class with optional message and verbosity settings."""
|
||||
self.msg = msg
|
||||
self.verbose = verbose
|
||||
|
||||
def __enter__(self):
|
||||
"""Executes when entering TryExcept context, initializes instance."""
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, value, traceback):
|
||||
"""Defines behavior when exiting a 'with' block, prints error message if necessary."""
|
||||
if self.verbose and value:
|
||||
print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
|
||||
return True
|
||||
|
||||
|
||||
def threaded(func):
|
||||
"""Multi-threads a target function and returns thread. Usage: @threaded decorator."""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Multi-threads a given function and returns the thread."""
|
||||
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def set_sentry():
|
||||
"""
|
||||
Initialize the Sentry SDK for error tracking and reporting. Only used if sentry_sdk package is installed and
|
||||
sync=True in settings. Run 'yolo settings' to see and update settings YAML file.
|
||||
|
||||
Conditions required to send errors (ALL conditions must be met or no errors will be reported):
|
||||
- sentry_sdk package is installed
|
||||
- sync=True in YOLO settings
|
||||
- pytest is not running
|
||||
- running in a pip package installation
|
||||
- running in a non-git directory
|
||||
- running with rank -1 or 0
|
||||
- online environment
|
||||
- CLI used to run package (checked with 'yolo' as the name of the main CLI command)
|
||||
|
||||
The function also configures Sentry SDK to ignore KeyboardInterrupt and FileNotFoundError
|
||||
exceptions and to exclude events with 'out of memory' in their exception message.
|
||||
|
||||
Additionally, the function sets custom tags and user information for Sentry events.
|
||||
"""
|
||||
|
||||
def before_send(event, hint):
|
||||
"""
|
||||
Modify the event before sending it to Sentry based on specific exception types and messages.
|
||||
|
||||
Args:
|
||||
event (dict): The event dictionary containing information about the error.
|
||||
hint (dict): A dictionary containing additional information about the error.
|
||||
|
||||
Returns:
|
||||
dict: The modified event or None if the event should not be sent to Sentry.
|
||||
"""
|
||||
if 'exc_info' in hint:
|
||||
exc_type, exc_value, tb = hint['exc_info']
|
||||
if exc_type in (KeyboardInterrupt, FileNotFoundError) \
|
||||
or 'out of memory' in str(exc_value):
|
||||
return None # do not send event
|
||||
|
||||
event['tags'] = {
|
||||
'sys_argv': sys.argv[0],
|
||||
'sys_argv_name': Path(sys.argv[0]).name,
|
||||
'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
|
||||
'os': ENVIRONMENT}
|
||||
return event
|
||||
|
||||
if SETTINGS['sync'] and \
|
||||
RANK in (-1, 0) and \
|
||||
Path(sys.argv[0]).name == 'yolo' and \
|
||||
not TESTS_RUNNING and \
|
||||
ONLINE and \
|
||||
is_pip_package() and \
|
||||
not is_git_dir():
|
||||
|
||||
# If sentry_sdk package is not installed then return and do not use Sentry
|
||||
try:
|
||||
import sentry_sdk # noqa
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn='https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016',
|
||||
debug=False,
|
||||
traces_sample_rate=1.0,
|
||||
release=__version__,
|
||||
environment='production', # 'dev' or 'production'
|
||||
before_send=before_send,
|
||||
ignore_errors=[KeyboardInterrupt, FileNotFoundError])
|
||||
sentry_sdk.set_user({'id': SETTINGS['uuid']}) # SHA-256 anonymized UUID hash
|
||||
|
||||
# Disable all sentry logging
|
||||
for logger in 'sentry_sdk', 'sentry_sdk.errors':
|
||||
logging.getLogger(logger).setLevel(logging.CRITICAL)
|
||||
|
||||
|
||||
def get_settings(file=SETTINGS_YAML, version='0.0.3'):
|
||||
"""
|
||||
Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist.
|
||||
|
||||
Args:
|
||||
file (Path): Path to the Ultralytics settings YAML file. Defaults to 'settings.yaml' in the USER_CONFIG_DIR.
|
||||
version (str): Settings version. If min settings version not met, new default settings will be saved.
|
||||
|
||||
Returns:
|
||||
(dict): Dictionary of settings key-value pairs.
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
from ultralytics.yolo.utils.checks import check_version
|
||||
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
|
||||
|
||||
git_dir = get_git_dir()
|
||||
root = git_dir or Path()
|
||||
datasets_root = (root.parent if git_dir and is_dir_writeable(root.parent) else root).resolve()
|
||||
defaults = {
|
||||
'datasets_dir': str(datasets_root / 'datasets'), # default datasets directory.
|
||||
'weights_dir': str(root / 'weights'), # default weights directory.
|
||||
'runs_dir': str(root / 'runs'), # default runs directory.
|
||||
'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), # SHA-256 anonymized UUID hash
|
||||
'sync': True, # sync analytics to help with YOLO development
|
||||
'api_key': '', # Ultralytics HUB API key (https://hub.ultralytics.com/)
|
||||
'settings_version': version} # Ultralytics settings version
|
||||
|
||||
with torch_distributed_zero_first(RANK):
|
||||
if not file.exists():
|
||||
yaml_save(file, defaults)
|
||||
settings = yaml_load(file)
|
||||
|
||||
# Check that settings keys and types match defaults
|
||||
correct = \
|
||||
settings \
|
||||
and settings.keys() == defaults.keys() \
|
||||
and all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) \
|
||||
and check_version(settings['settings_version'], version)
|
||||
if not correct:
|
||||
LOGGER.warning('WARNING ⚠️ Ultralytics settings reset to defaults. This is normal and may be due to a '
|
||||
'recent ultralytics package update, but may have overwritten previous settings. '
|
||||
f"\nView and update settings with 'yolo settings' or at '{file}'")
|
||||
settings = defaults # merge **defaults with **settings (prefer **settings)
|
||||
yaml_save(file, settings) # save updated defaults
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
def set_settings(kwargs, file=SETTINGS_YAML):
|
||||
"""
|
||||
Function that runs on a first-time ultralytics package installation to set up global settings and create necessary
|
||||
directories.
|
||||
"""
|
||||
SETTINGS.update(kwargs)
|
||||
yaml_save(file, SETTINGS)
|
||||
|
||||
|
||||
def deprecation_warn(arg, new_arg, version=None):
|
||||
"""Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument."""
|
||||
if not version:
|
||||
version = float(__version__[:3]) + 0.2 # deprecate after 2nd major release
|
||||
LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
|
||||
f"Please use '{new_arg}' instead.")
|
||||
|
||||
|
||||
def clean_url(url):
|
||||
"""Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt."""
|
||||
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
|
||||
return urllib.parse.unquote(url).split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
|
||||
|
||||
|
||||
def url2file(url):
|
||||
"""Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt."""
|
||||
return Path(clean_url(url)).name
|
||||
|
||||
|
||||
# Run below code on yolo/utils init ------------------------------------------------------------------------------------
|
||||
|
||||
# Check first-install steps
|
||||
PREFIX = colorstr('Ultralytics: ')
|
||||
SETTINGS = get_settings()
|
||||
DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory
|
||||
ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
|
||||
'Docker' if is_docker() else platform.system()
|
||||
TESTS_RUNNING = is_pytest_running() or is_github_actions_ci()
|
||||
set_sentry()
|
||||
|
||||
# Apply monkey patches if the script is being run from within the parent directory of the script's location
|
||||
from .patches import imread, imshow, imwrite
|
||||
|
||||
# torch.save = torch_save
|
||||
if Path(inspect.stack()[0].filename).parent.parent.as_posix() in inspect.stack()[-1].filename:
|
||||
cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow
|
||||
UTILS_WARNING = """WARNING ⚠️ 'ultralytics.yolo.utils' is deprecated since '8.0.136' and will be removed in '8.1.0'. Please use 'ultralytics.utils' instead.
|
||||
Note this warning may be related to loading older models. You can update your model to current structure with:
|
||||
import torch
|
||||
ckpt = torch.load("model.pt") # applies to both official and custom models
|
||||
torch.save(ckpt, "updated-model.pt")
|
||||
"""
|
||||
LOGGER.warning(UTILS_WARNING)
|
||||
|
@ -1,90 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, colorstr
|
||||
from ultralytics.yolo.utils.torch_utils import profile
|
||||
|
||||
|
||||
def check_train_batch_size(model, imgsz=640, amp=True):
|
||||
"""
|
||||
Check YOLO training batch size using the autobatch() function.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): YOLO model to check batch size for.
|
||||
imgsz (int): Image size used for training.
|
||||
amp (bool): If True, use automatic mixed precision (AMP) for training.
|
||||
|
||||
Returns:
|
||||
(int): Optimal batch size computed using the autobatch() function.
|
||||
"""
|
||||
|
||||
with torch.cuda.amp.autocast(amp):
|
||||
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
|
||||
|
||||
|
||||
def autobatch(model, imgsz=640, fraction=0.67, batch_size=DEFAULT_CFG.batch):
|
||||
"""
|
||||
Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
|
||||
|
||||
Args:
|
||||
model (torch.nn.module): YOLO model to compute batch size for.
|
||||
imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
|
||||
fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67.
|
||||
batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.
|
||||
|
||||
Returns:
|
||||
(int): The optimal batch size.
|
||||
"""
|
||||
|
||||
# Check device
|
||||
prefix = colorstr('AutoBatch: ')
|
||||
LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
|
||||
device = next(model.parameters()).device # get model device
|
||||
if device.type == 'cpu':
|
||||
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
|
||||
return batch_size
|
||||
if torch.backends.cudnn.benchmark:
|
||||
LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
|
||||
return batch_size
|
||||
|
||||
# Inspect CUDA memory
|
||||
gb = 1 << 30 # bytes to GiB (1024 ** 3)
|
||||
d = str(device).upper() # 'CUDA:0'
|
||||
properties = torch.cuda.get_device_properties(device) # device properties
|
||||
t = properties.total_memory / gb # GiB total
|
||||
r = torch.cuda.memory_reserved(device) / gb # GiB reserved
|
||||
a = torch.cuda.memory_allocated(device) / gb # GiB allocated
|
||||
f = t - (r + a) # GiB free
|
||||
LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
|
||||
|
||||
# Profile batch sizes
|
||||
batch_sizes = [1, 2, 4, 8, 16]
|
||||
try:
|
||||
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
|
||||
results = profile(img, model, n=3, device=device)
|
||||
|
||||
# Fit a solution
|
||||
y = [x[2] for x in results if x] # memory [2]
|
||||
p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
|
||||
b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
|
||||
if None in results: # some sizes failed
|
||||
i = results.index(None) # first fail index
|
||||
if b >= batch_sizes[i]: # y intercept above failure point
|
||||
b = batch_sizes[max(i - 1, 0)] # select prior safe point
|
||||
if b < 1 or b > 1024: # b outside of safe range
|
||||
b = batch_size
|
||||
LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.')
|
||||
|
||||
fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
|
||||
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
|
||||
return b
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.')
|
||||
return batch_size
|
@ -1,358 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Benchmark a YOLO model formats for speed and accuracy
|
||||
|
||||
Usage:
|
||||
from ultralytics.yolo.utils.benchmarks import ProfileModels, benchmark
|
||||
ProfileModels(['yolov8n.yaml', 'yolov8s.yaml']).profile()
|
||||
run_benchmarks(model='yolov8n.pt', imgsz=160)
|
||||
|
||||
Format | `format=argument` | Model
|
||||
--- | --- | ---
|
||||
PyTorch | - | yolov8n.pt
|
||||
TorchScript | `torchscript` | yolov8n.torchscript
|
||||
ONNX | `onnx` | yolov8n.onnx
|
||||
OpenVINO | `openvino` | yolov8n_openvino_model/
|
||||
TensorRT | `engine` | yolov8n.engine
|
||||
CoreML | `coreml` | yolov8n.mlmodel
|
||||
TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/
|
||||
TensorFlow GraphDef | `pb` | yolov8n.pb
|
||||
TensorFlow Lite | `tflite` | yolov8n.tflite
|
||||
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
|
||||
TensorFlow.js | `tfjs` | yolov8n_web_model/
|
||||
PaddlePaddle | `paddle` | yolov8n_paddle_model/
|
||||
ncnn | `ncnn` | yolov8n_ncnn_model/
|
||||
"""
|
||||
|
||||
import glob
|
||||
import platform
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch.cuda
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.yolo.cfg import TASK2DATA, TASK2METRIC
|
||||
from ultralytics.yolo.engine.exporter import export_formats
|
||||
from ultralytics.yolo.utils import LINUX, LOGGER, MACOS, ROOT, SETTINGS
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_yolo
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.yolo.utils.files import file_size
|
||||
from ultralytics.yolo.utils.torch_utils import select_device
|
||||
|
||||
|
||||
def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
|
||||
imgsz=160,
|
||||
half=False,
|
||||
int8=False,
|
||||
device='cpu',
|
||||
hard_fail=False):
|
||||
"""
|
||||
Benchmark a YOLO model across different formats for speed and accuracy.
|
||||
|
||||
Args:
|
||||
model (str | Path | optional): Path to the model file or directory. Default is
|
||||
Path(SETTINGS['weights_dir']) / 'yolov8n.pt'.
|
||||
imgsz (int, optional): Image size for the benchmark. Default is 160.
|
||||
half (bool, optional): Use half-precision for the model if True. Default is False.
|
||||
int8 (bool, optional): Use int8-precision for the model if True. Default is False.
|
||||
device (str, optional): Device to run the benchmark on, either 'cpu' or 'cuda'. Default is 'cpu'.
|
||||
hard_fail (bool | float | optional): If True or a float, assert benchmarks pass with given metric.
|
||||
Default is False.
|
||||
|
||||
Returns:
|
||||
df (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size,
|
||||
metric, and inference time.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
pd.options.display.max_columns = 10
|
||||
pd.options.display.width = 120
|
||||
device = select_device(device, verbose=False)
|
||||
if isinstance(model, (str, Path)):
|
||||
model = YOLO(model)
|
||||
|
||||
y = []
|
||||
t0 = time.time()
|
||||
for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU)
|
||||
emoji, filename = '❌', None # export defaults
|
||||
try:
|
||||
assert i != 9 or LINUX, 'Edge TPU export only supported on Linux'
|
||||
if i == 10:
|
||||
assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux'
|
||||
if 'cpu' in device.type:
|
||||
assert cpu, 'inference not supported on CPU'
|
||||
if 'cuda' in device.type:
|
||||
assert gpu, 'inference not supported on GPU'
|
||||
|
||||
# Export
|
||||
if format == '-':
|
||||
filename = model.ckpt_path or model.cfg
|
||||
export = model # PyTorch format
|
||||
else:
|
||||
filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device, verbose=False)
|
||||
export = YOLO(filename, task=model.task)
|
||||
assert suffix in str(filename), 'export failed'
|
||||
emoji = '❎' # indicates export succeeded
|
||||
|
||||
# Predict
|
||||
assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported'
|
||||
assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
|
||||
assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
|
||||
if not (ROOT / 'assets/bus.jpg').exists():
|
||||
download(url='https://ultralytics.com/images/bus.jpg', dir=ROOT / 'assets')
|
||||
export.predict(ROOT / 'assets/bus.jpg', imgsz=imgsz, device=device, half=half)
|
||||
|
||||
# Validate
|
||||
data = TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect
|
||||
key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect
|
||||
results = export.val(data=data,
|
||||
batch=1,
|
||||
imgsz=imgsz,
|
||||
plots=False,
|
||||
device=device,
|
||||
half=half,
|
||||
int8=int8,
|
||||
verbose=False)
|
||||
metric, speed = results.results_dict[key], results.speed['inference']
|
||||
y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
|
||||
except Exception as e:
|
||||
if hard_fail:
|
||||
assert type(e) is AssertionError, f'Benchmark hard_fail for {name}: {e}'
|
||||
LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}')
|
||||
y.append([name, emoji, round(file_size(filename), 1), None, None]) # mAP, t_inference
|
||||
|
||||
# Print results
|
||||
check_yolo(device=device) # print system info
|
||||
df = pd.DataFrame(y, columns=['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)'])
|
||||
|
||||
name = Path(model.ckpt_path).name
|
||||
s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n'
|
||||
LOGGER.info(s)
|
||||
with open('benchmarks.log', 'a', errors='ignore', encoding='utf-8') as f:
|
||||
f.write(s)
|
||||
|
||||
if hard_fail and isinstance(hard_fail, float):
|
||||
metrics = df[key].array # values to compare to floor
|
||||
floor = hard_fail # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
|
||||
assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: one or more metric(s) < floor {floor}'
|
||||
|
||||
return df
|
||||
|
||||
|
||||
class ProfileModels:
|
||||
"""
|
||||
ProfileModels class for profiling different models on ONNX and TensorRT.
|
||||
|
||||
This class profiles the performance of different models, provided their paths. The profiling includes parameters such as
|
||||
model speed and FLOPs.
|
||||
|
||||
Attributes:
|
||||
paths (list): Paths of the models to profile.
|
||||
num_timed_runs (int): Number of timed runs for the profiling. Default is 100.
|
||||
num_warmup_runs (int): Number of warmup runs before profiling. Default is 10.
|
||||
min_time (float): Minimum number of seconds to profile for. Default is 60.
|
||||
imgsz (int): Image size used in the models. Default is 640.
|
||||
|
||||
Methods:
|
||||
profile(): Profiles the models and prints the result.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
paths: list,
|
||||
num_timed_runs=100,
|
||||
num_warmup_runs=10,
|
||||
min_time=60,
|
||||
imgsz=640,
|
||||
trt=True,
|
||||
device=None):
|
||||
self.paths = paths
|
||||
self.num_timed_runs = num_timed_runs
|
||||
self.num_warmup_runs = num_warmup_runs
|
||||
self.min_time = min_time
|
||||
self.imgsz = imgsz
|
||||
self.trt = trt # run TensorRT profiling
|
||||
self.device = device or torch.device(0 if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def profile(self):
|
||||
files = self.get_files()
|
||||
|
||||
if not files:
|
||||
print('No matching *.pt or *.onnx files found.')
|
||||
return
|
||||
|
||||
table_rows = []
|
||||
output = []
|
||||
for file in files:
|
||||
engine_file = file.with_suffix('.engine')
|
||||
if file.suffix in ('.pt', '.yaml'):
|
||||
model = YOLO(str(file))
|
||||
model.fuse() # to report correct params and GFLOPs in model.info()
|
||||
model_info = model.info()
|
||||
if self.trt and self.device.type != 'cpu' and not engine_file.is_file():
|
||||
engine_file = model.export(format='engine',
|
||||
half=True,
|
||||
imgsz=self.imgsz,
|
||||
device=self.device,
|
||||
verbose=False)
|
||||
onnx_file = model.export(format='onnx',
|
||||
half=True,
|
||||
imgsz=self.imgsz,
|
||||
simplify=True,
|
||||
device=self.device,
|
||||
verbose=False)
|
||||
elif file.suffix == '.onnx':
|
||||
model_info = self.get_onnx_model_info(file)
|
||||
onnx_file = file
|
||||
else:
|
||||
continue
|
||||
|
||||
t_engine = self.profile_tensorrt_model(str(engine_file))
|
||||
t_onnx = self.profile_onnx_model(str(onnx_file))
|
||||
table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info))
|
||||
output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info))
|
||||
|
||||
self.print_table(table_rows)
|
||||
return output
|
||||
|
||||
def get_files(self):
|
||||
files = []
|
||||
for path in self.paths:
|
||||
path = Path(path)
|
||||
if path.is_dir():
|
||||
extensions = ['*.pt', '*.onnx', '*.yaml']
|
||||
files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
|
||||
elif path.suffix in {'.pt', '.yaml'}: # add non-existing
|
||||
files.append(str(path))
|
||||
else:
|
||||
files.extend(glob.glob(str(path)))
|
||||
|
||||
print(f'Profiling: {sorted(files)}')
|
||||
return [Path(file) for file in sorted(files)]
|
||||
|
||||
def get_onnx_model_info(self, onnx_file: str):
|
||||
# return (num_layers, num_params, num_gradients, num_flops)
|
||||
return 0.0, 0.0, 0.0, 0.0
|
||||
|
||||
def iterative_sigma_clipping(self, data, sigma=2, max_iters=3):
|
||||
data = np.array(data)
|
||||
for _ in range(max_iters):
|
||||
mean, std = np.mean(data), np.std(data)
|
||||
clipped_data = data[(data > mean - sigma * std) & (data < mean + sigma * std)]
|
||||
if len(clipped_data) == len(data):
|
||||
break
|
||||
data = clipped_data
|
||||
return data
|
||||
|
||||
def profile_tensorrt_model(self, engine_file: str):
|
||||
if not self.trt or not Path(engine_file).is_file():
|
||||
return 0.0, 0.0
|
||||
|
||||
# Model and input
|
||||
model = YOLO(engine_file)
|
||||
input_data = np.random.rand(self.imgsz, self.imgsz, 3).astype(np.float32) # must be FP32
|
||||
|
||||
# Warmup runs
|
||||
elapsed = 0.0
|
||||
for _ in range(3):
|
||||
start_time = time.time()
|
||||
for _ in range(self.num_warmup_runs):
|
||||
model(input_data, imgsz=self.imgsz, verbose=False)
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Compute number of runs as higher of min_time or num_timed_runs
|
||||
num_runs = max(round(self.min_time / elapsed * self.num_warmup_runs), self.num_timed_runs * 50)
|
||||
|
||||
# Timed runs
|
||||
run_times = []
|
||||
for _ in tqdm(range(num_runs), desc=engine_file):
|
||||
results = model(input_data, imgsz=self.imgsz, verbose=False)
|
||||
run_times.append(results[0].speed['inference']) # Convert to milliseconds
|
||||
|
||||
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping
|
||||
return np.mean(run_times), np.std(run_times)
|
||||
|
||||
def profile_onnx_model(self, onnx_file: str):
|
||||
check_requirements('onnxruntime')
|
||||
import onnxruntime as ort
|
||||
|
||||
# Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
sess_options.intra_op_num_threads = 8 # Limit the number of threads
|
||||
sess = ort.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider'])
|
||||
|
||||
input_tensor = sess.get_inputs()[0]
|
||||
input_type = input_tensor.type
|
||||
|
||||
# Mapping ONNX datatype to numpy datatype
|
||||
if 'float16' in input_type:
|
||||
input_dtype = np.float16
|
||||
elif 'float' in input_type:
|
||||
input_dtype = np.float32
|
||||
elif 'double' in input_type:
|
||||
input_dtype = np.float64
|
||||
elif 'int64' in input_type:
|
||||
input_dtype = np.int64
|
||||
elif 'int32' in input_type:
|
||||
input_dtype = np.int32
|
||||
else:
|
||||
raise ValueError(f'Unsupported ONNX datatype {input_type}')
|
||||
|
||||
input_data = np.random.rand(*input_tensor.shape).astype(input_dtype)
|
||||
input_name = input_tensor.name
|
||||
output_name = sess.get_outputs()[0].name
|
||||
|
||||
# Warmup runs
|
||||
elapsed = 0.0
|
||||
for _ in range(3):
|
||||
start_time = time.time()
|
||||
for _ in range(self.num_warmup_runs):
|
||||
sess.run([output_name], {input_name: input_data})
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Compute number of runs as higher of min_time or num_timed_runs
|
||||
num_runs = max(round(self.min_time / elapsed * self.num_warmup_runs), self.num_timed_runs)
|
||||
|
||||
# Timed runs
|
||||
run_times = []
|
||||
for _ in tqdm(range(num_runs), desc=onnx_file):
|
||||
start_time = time.time()
|
||||
sess.run([output_name], {input_name: input_data})
|
||||
run_times.append((time.time() - start_time) * 1000) # Convert to milliseconds
|
||||
|
||||
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=5) # sigma clipping
|
||||
return np.mean(run_times), np.std(run_times)
|
||||
|
||||
def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
|
||||
layers, params, gradients, flops = model_info
|
||||
return f'| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |'
|
||||
|
||||
def generate_results_dict(self, model_name, t_onnx, t_engine, model_info):
|
||||
layers, params, gradients, flops = model_info
|
||||
return {
|
||||
'model/name': model_name,
|
||||
'model/parameters': params,
|
||||
'model/GFLOPs': round(flops, 3),
|
||||
'model/speed_ONNX(ms)': round(t_onnx[0], 3),
|
||||
'model/speed_TensorRT(ms)': round(t_engine[0], 3)}
|
||||
|
||||
def print_table(self, table_rows):
|
||||
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU'
|
||||
header = f'| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |'
|
||||
separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|'
|
||||
|
||||
print(f'\n\n{header}')
|
||||
print(separator)
|
||||
for row in table_rows:
|
||||
print(row)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Benchmark all export formats
|
||||
benchmark()
|
||||
|
||||
# Profiling models on ONNX and TensorRT
|
||||
ProfileModels(['yolov8n.yaml', 'yolov8s.yaml'])
|
@ -1,5 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
|
||||
|
||||
__all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks'
|
@ -1,212 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Base callbacks
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
|
||||
# Trainer callbacks ----------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
"""Called before the pretraining routine starts."""
|
||||
pass
|
||||
|
||||
|
||||
def on_pretrain_routine_end(trainer):
|
||||
"""Called after the pretraining routine ends."""
|
||||
pass
|
||||
|
||||
|
||||
def on_train_start(trainer):
|
||||
"""Called when the training starts."""
|
||||
pass
|
||||
|
||||
|
||||
def on_train_epoch_start(trainer):
|
||||
"""Called at the start of each training epoch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_train_batch_start(trainer):
|
||||
"""Called at the start of each training batch."""
|
||||
pass
|
||||
|
||||
|
||||
def optimizer_step(trainer):
|
||||
"""Called when the optimizer takes a step."""
|
||||
pass
|
||||
|
||||
|
||||
def on_before_zero_grad(trainer):
|
||||
"""Called before the gradients are set to zero."""
|
||||
pass
|
||||
|
||||
|
||||
def on_train_batch_end(trainer):
|
||||
"""Called at the end of each training batch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Called at the end of each training epoch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Called at the end of each fit epoch (train + val)."""
|
||||
pass
|
||||
|
||||
|
||||
def on_model_save(trainer):
|
||||
"""Called when the model is saved."""
|
||||
pass
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Called when the training ends."""
|
||||
pass
|
||||
|
||||
|
||||
def on_params_update(trainer):
|
||||
"""Called when the model parameters are updated."""
|
||||
pass
|
||||
|
||||
|
||||
def teardown(trainer):
|
||||
"""Called during the teardown of the training process."""
|
||||
pass
|
||||
|
||||
|
||||
# Validator callbacks --------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def on_val_start(validator):
|
||||
"""Called when the validation starts."""
|
||||
pass
|
||||
|
||||
|
||||
def on_val_batch_start(validator):
|
||||
"""Called at the start of each validation batch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_val_batch_end(validator):
|
||||
"""Called at the end of each validation batch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_val_end(validator):
|
||||
"""Called when the validation ends."""
|
||||
pass
|
||||
|
||||
|
||||
# Predictor callbacks --------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def on_predict_start(predictor):
|
||||
"""Called when the prediction starts."""
|
||||
pass
|
||||
|
||||
|
||||
def on_predict_batch_start(predictor):
|
||||
"""Called at the start of each prediction batch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_predict_batch_end(predictor):
|
||||
"""Called at the end of each prediction batch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_predict_postprocess_end(predictor):
|
||||
"""Called after the post-processing of the prediction ends."""
|
||||
pass
|
||||
|
||||
|
||||
def on_predict_end(predictor):
|
||||
"""Called when the prediction ends."""
|
||||
pass
|
||||
|
||||
|
||||
# Exporter callbacks ---------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def on_export_start(exporter):
|
||||
"""Called when the model export starts."""
|
||||
pass
|
||||
|
||||
|
||||
def on_export_end(exporter):
|
||||
"""Called when the model export ends."""
|
||||
pass
|
||||
|
||||
|
||||
default_callbacks = {
|
||||
# Run in trainer
|
||||
'on_pretrain_routine_start': [on_pretrain_routine_start],
|
||||
'on_pretrain_routine_end': [on_pretrain_routine_end],
|
||||
'on_train_start': [on_train_start],
|
||||
'on_train_epoch_start': [on_train_epoch_start],
|
||||
'on_train_batch_start': [on_train_batch_start],
|
||||
'optimizer_step': [optimizer_step],
|
||||
'on_before_zero_grad': [on_before_zero_grad],
|
||||
'on_train_batch_end': [on_train_batch_end],
|
||||
'on_train_epoch_end': [on_train_epoch_end],
|
||||
'on_fit_epoch_end': [on_fit_epoch_end], # fit = train + val
|
||||
'on_model_save': [on_model_save],
|
||||
'on_train_end': [on_train_end],
|
||||
'on_params_update': [on_params_update],
|
||||
'teardown': [teardown],
|
||||
|
||||
# Run in validator
|
||||
'on_val_start': [on_val_start],
|
||||
'on_val_batch_start': [on_val_batch_start],
|
||||
'on_val_batch_end': [on_val_batch_end],
|
||||
'on_val_end': [on_val_end],
|
||||
|
||||
# Run in predictor
|
||||
'on_predict_start': [on_predict_start],
|
||||
'on_predict_batch_start': [on_predict_batch_start],
|
||||
'on_predict_postprocess_end': [on_predict_postprocess_end],
|
||||
'on_predict_batch_end': [on_predict_batch_end],
|
||||
'on_predict_end': [on_predict_end],
|
||||
|
||||
# Run in exporter
|
||||
'on_export_start': [on_export_start],
|
||||
'on_export_end': [on_export_end]}
|
||||
|
||||
|
||||
def get_default_callbacks():
|
||||
"""
|
||||
Return a copy of the default_callbacks dictionary with lists as default values.
|
||||
|
||||
Returns:
|
||||
(defaultdict): A defaultdict with keys from default_callbacks and empty lists as default values.
|
||||
"""
|
||||
return defaultdict(list, deepcopy(default_callbacks))
|
||||
|
||||
|
||||
def add_integration_callbacks(instance):
|
||||
"""
|
||||
Add integration callbacks from various sources to the instance's callbacks.
|
||||
|
||||
Args:
|
||||
instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
|
||||
of callback lists.
|
||||
"""
|
||||
from .clearml import callbacks as clearml_cb
|
||||
from .comet import callbacks as comet_cb
|
||||
from .dvc import callbacks as dvc_cb
|
||||
from .hub import callbacks as hub_cb
|
||||
from .mlflow import callbacks as mlflow_cb
|
||||
from .neptune import callbacks as neptune_cb
|
||||
from .raytune import callbacks as tune_cb
|
||||
from .tensorboard import callbacks as tensorboard_cb
|
||||
from .wb import callbacks as wb_cb
|
||||
|
||||
for x in clearml_cb, comet_cb, hub_cb, mlflow_cb, neptune_cb, tune_cb, tensorboard_cb, wb_cb, dvc_cb:
|
||||
for k, v in x.items():
|
||||
if v not in instance.callbacks[k]: # prevent duplicate callbacks addition
|
||||
instance.callbacks[k].append(v) # callback[name].append(func)
|
@ -1,143 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import re
|
||||
|
||||
import matplotlib.image as mpimg
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
|
||||
from ultralytics.yolo.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
try:
|
||||
import clearml
|
||||
from clearml import Task
|
||||
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||
from clearml.binding.matplotlib_bind import PatchedMatplotlib
|
||||
|
||||
assert hasattr(clearml, '__version__') # verify package is not directory
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
except (ImportError, AssertionError):
|
||||
clearml = None
|
||||
|
||||
|
||||
def _log_debug_samples(files, title='Debug Samples') -> None:
|
||||
"""
|
||||
Log files (images) as debug samples in the ClearML task.
|
||||
|
||||
Args:
|
||||
files (list): A list of file paths in PosixPath format.
|
||||
title (str): A title that groups together images with the same values.
|
||||
"""
|
||||
task = Task.current_task()
|
||||
if task:
|
||||
for f in files:
|
||||
if f.exists():
|
||||
it = re.search(r'_batch(\d+)', f.name)
|
||||
iteration = int(it.groups()[0]) if it else 0
|
||||
task.get_logger().report_image(title=title,
|
||||
series=f.name.replace(it.group(), ''),
|
||||
local_path=str(f),
|
||||
iteration=iteration)
|
||||
|
||||
|
||||
def _log_plot(title, plot_path) -> None:
|
||||
"""
|
||||
Log an image as a plot in the plot section of ClearML.
|
||||
|
||||
Args:
|
||||
title (str): The title of the plot.
|
||||
plot_path (str): The path to the saved image file.
|
||||
"""
|
||||
img = mpimg.imread(plot_path)
|
||||
fig = plt.figure()
|
||||
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
|
||||
ax.imshow(img)
|
||||
|
||||
Task.current_task().get_logger().report_matplotlib_figure(title=title,
|
||||
series='',
|
||||
figure=fig,
|
||||
report_interactive=False)
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
"""Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
|
||||
try:
|
||||
task = Task.current_task()
|
||||
if task:
|
||||
# Make sure the automatic pytorch and matplotlib bindings are disabled!
|
||||
# We are logging these plots and model files manually in the integration
|
||||
PatchPyTorchModelIO.update_current_task(None)
|
||||
PatchedMatplotlib.update_current_task(None)
|
||||
else:
|
||||
task = Task.init(project_name=trainer.args.project or 'YOLOv8',
|
||||
task_name=trainer.args.name,
|
||||
tags=['YOLOv8'],
|
||||
output_uri=True,
|
||||
reuse_last_task_id=False,
|
||||
auto_connect_frameworks={
|
||||
'pytorch': False,
|
||||
'matplotlib': False})
|
||||
LOGGER.warning('ClearML Initialized a new task. If you want to run remotely, '
|
||||
'please add clearml-init and connect your arguments before initializing YOLO.')
|
||||
task.connect(vars(trainer.args), name='General')
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
task = Task.current_task()
|
||||
|
||||
if task:
|
||||
"""Logs debug samples for the first epoch of YOLO training."""
|
||||
if trainer.epoch == 1:
|
||||
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
|
||||
"""Report the current training progress."""
|
||||
for k, v in trainer.validator.metrics.results_dict.items():
|
||||
task.get_logger().report_scalar('train', k, v, iteration=trainer.epoch)
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Reports model information to logger at the end of an epoch."""
|
||||
task = Task.current_task()
|
||||
if task:
|
||||
# You should have access to the validation bboxes under jdict
|
||||
task.get_logger().report_scalar(title='Epoch Time',
|
||||
series='Epoch Time',
|
||||
value=trainer.epoch_time,
|
||||
iteration=trainer.epoch)
|
||||
if trainer.epoch == 0:
|
||||
for k, v in model_info_for_loggers(trainer).items():
|
||||
task.get_logger().report_single_value(k, v)
|
||||
|
||||
|
||||
def on_val_end(validator):
|
||||
"""Logs validation results including labels and predictions."""
|
||||
if Task.current_task():
|
||||
# Log val_labels and val_pred
|
||||
_log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation')
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Logs final model and its name on training completion."""
|
||||
task = Task.current_task()
|
||||
if task:
|
||||
# Log final results, CM matrix + PR plots
|
||||
files = [
|
||||
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
|
||||
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
|
||||
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
|
||||
for f in files:
|
||||
_log_plot(title=f.stem, plot_path=f)
|
||||
# Report final metrics
|
||||
for k, v in trainer.validator.metrics.results_dict.items():
|
||||
task.get_logger().report_single_value(k, v)
|
||||
# Log the final model
|
||||
task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_train_epoch_end': on_train_epoch_end,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_val_end': on_val_end,
|
||||
'on_train_end': on_train_end} if clearml else {}
|
@ -1,368 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, RANK, TESTS_RUNNING, ops
|
||||
from ultralytics.yolo.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
try:
|
||||
import comet_ml
|
||||
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert hasattr(comet_ml, '__version__') # verify package is not directory
|
||||
except (ImportError, AssertionError):
|
||||
comet_ml = None
|
||||
|
||||
# Ensures certain logging functions only run for supported tasks
|
||||
COMET_SUPPORTED_TASKS = ['detect']
|
||||
|
||||
# Names of plots created by YOLOv8 that are logged to Comet
|
||||
EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix'
|
||||
LABEL_PLOT_NAMES = 'labels', 'labels_correlogram'
|
||||
|
||||
_comet_image_prediction_count = 0
|
||||
|
||||
|
||||
def _get_comet_mode():
|
||||
return os.getenv('COMET_MODE', 'online')
|
||||
|
||||
|
||||
def _get_comet_model_name():
|
||||
return os.getenv('COMET_MODEL_NAME', 'YOLOv8')
|
||||
|
||||
|
||||
def _get_eval_batch_logging_interval():
|
||||
return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))
|
||||
|
||||
|
||||
def _get_max_image_predictions_to_log():
|
||||
return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))
|
||||
|
||||
|
||||
def _scale_confidence_score(score):
|
||||
scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0))
|
||||
return score * scale
|
||||
|
||||
|
||||
def _should_log_confusion_matrix():
|
||||
return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'false').lower() == 'true'
|
||||
|
||||
|
||||
def _should_log_image_predictions():
|
||||
return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true'
|
||||
|
||||
|
||||
def _get_experiment_type(mode, project_name):
|
||||
"""Return an experiment based on mode and project name."""
|
||||
if mode == 'offline':
|
||||
return comet_ml.OfflineExperiment(project_name=project_name)
|
||||
|
||||
return comet_ml.Experiment(project_name=project_name)
|
||||
|
||||
|
||||
def _create_experiment(args):
|
||||
"""Ensures that the experiment object is only created in a single process during distributed training."""
|
||||
if RANK not in (-1, 0):
|
||||
return
|
||||
try:
|
||||
comet_mode = _get_comet_mode()
|
||||
_project_name = os.getenv('COMET_PROJECT_NAME', args.project)
|
||||
experiment = _get_experiment_type(comet_mode, _project_name)
|
||||
experiment.log_parameters(vars(args))
|
||||
experiment.log_others({
|
||||
'eval_batch_logging_interval': _get_eval_batch_logging_interval(),
|
||||
'log_confusion_matrix_on_eval': _should_log_confusion_matrix(),
|
||||
'log_image_predictions': _should_log_image_predictions(),
|
||||
'max_image_predictions': _get_max_image_predictions_to_log(), })
|
||||
experiment.log_other('Created from', 'yolov8')
|
||||
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')
|
||||
|
||||
|
||||
def _fetch_trainer_metadata(trainer):
|
||||
"""Returns metadata for YOLO training including epoch and asset saving status."""
|
||||
curr_epoch = trainer.epoch + 1
|
||||
|
||||
train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
|
||||
curr_step = curr_epoch * train_num_steps_per_epoch
|
||||
final_epoch = curr_epoch == trainer.epochs
|
||||
|
||||
save = trainer.args.save
|
||||
save_period = trainer.args.save_period
|
||||
save_interval = curr_epoch % save_period == 0
|
||||
save_assets = save and save_period > 0 and save_interval and not final_epoch
|
||||
|
||||
return dict(
|
||||
curr_epoch=curr_epoch,
|
||||
curr_step=curr_step,
|
||||
save_assets=save_assets,
|
||||
final_epoch=final_epoch,
|
||||
)
|
||||
|
||||
|
||||
def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
|
||||
"""YOLOv8 resizes images during training and the label values
|
||||
are normalized based on this resized shape. This function rescales the
|
||||
bounding box labels to the original image shape.
|
||||
"""
|
||||
|
||||
resized_image_height, resized_image_width = resized_image_shape
|
||||
|
||||
# Convert normalized xywh format predictions to xyxy in resized scale format
|
||||
box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
|
||||
# Scale box predictions from resized image scale back to original image scale
|
||||
box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
|
||||
# Convert bounding box format from xyxy to xywh for Comet logging
|
||||
box = ops.xyxy2xywh(box)
|
||||
# Adjust xy center to correspond top-left corner
|
||||
box[:2] -= box[2:] / 2
|
||||
box = box.tolist()
|
||||
|
||||
return box
|
||||
|
||||
|
||||
def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
|
||||
"""Format ground truth annotations for detection."""
|
||||
indices = batch['batch_idx'] == img_idx
|
||||
bboxes = batch['bboxes'][indices]
|
||||
if len(bboxes) == 0:
|
||||
LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes labels')
|
||||
return None
|
||||
|
||||
cls_labels = batch['cls'][indices].squeeze(1).tolist()
|
||||
if class_name_map:
|
||||
cls_labels = [str(class_name_map[label]) for label in cls_labels]
|
||||
|
||||
original_image_shape = batch['ori_shape'][img_idx]
|
||||
resized_image_shape = batch['resized_shape'][img_idx]
|
||||
ratio_pad = batch['ratio_pad'][img_idx]
|
||||
|
||||
data = []
|
||||
for box, label in zip(bboxes, cls_labels):
|
||||
box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
|
||||
data.append({
|
||||
'boxes': [box],
|
||||
'label': f'gt_{label}',
|
||||
'score': _scale_confidence_score(1.0), })
|
||||
|
||||
return {'name': 'ground_truth', 'data': data}
|
||||
|
||||
|
||||
def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
|
||||
"""Format YOLO predictions for object detection visualization."""
|
||||
stem = image_path.stem
|
||||
image_id = int(stem) if stem.isnumeric() else stem
|
||||
|
||||
predictions = metadata.get(image_id)
|
||||
if not predictions:
|
||||
LOGGER.debug(f'COMET WARNING: Image: {image_path} has no bounding boxes predictions')
|
||||
return None
|
||||
|
||||
data = []
|
||||
for prediction in predictions:
|
||||
boxes = prediction['bbox']
|
||||
score = _scale_confidence_score(prediction['score'])
|
||||
cls_label = prediction['category_id']
|
||||
if class_label_map:
|
||||
cls_label = str(class_label_map[cls_label])
|
||||
|
||||
data.append({'boxes': [boxes], 'label': cls_label, 'score': score})
|
||||
|
||||
return {'name': 'prediction', 'data': data}
|
||||
|
||||
|
||||
def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
|
||||
"""Join the ground truth and prediction annotations if they exist."""
|
||||
ground_truth_annotations = _format_ground_truth_annotations_for_detection(img_idx, image_path, batch,
|
||||
class_label_map)
|
||||
prediction_annotations = _format_prediction_annotations_for_detection(image_path, prediction_metadata_map,
|
||||
class_label_map)
|
||||
|
||||
annotations = [
|
||||
annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None]
|
||||
return [annotations] if annotations else None
|
||||
|
||||
|
||||
def _create_prediction_metadata_map(model_predictions):
|
||||
"""Create metadata map for model predictions by groupings them based on image ID."""
|
||||
pred_metadata_map = {}
|
||||
for prediction in model_predictions:
|
||||
pred_metadata_map.setdefault(prediction['image_id'], [])
|
||||
pred_metadata_map[prediction['image_id']].append(prediction)
|
||||
|
||||
return pred_metadata_map
|
||||
|
||||
|
||||
def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
|
||||
"""Log the confusion matrix to Comet experiment."""
|
||||
conf_mat = trainer.validator.confusion_matrix.matrix
|
||||
names = list(trainer.data['names'].values()) + ['background']
|
||||
experiment.log_confusion_matrix(
|
||||
matrix=conf_mat,
|
||||
labels=names,
|
||||
max_categories=len(names),
|
||||
epoch=curr_epoch,
|
||||
step=curr_step,
|
||||
)
|
||||
|
||||
|
||||
def _log_images(experiment, image_paths, curr_step, annotations=None):
|
||||
"""Logs images to the experiment with optional annotations."""
|
||||
if annotations:
|
||||
for image_path, annotation in zip(image_paths, annotations):
|
||||
experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)
|
||||
|
||||
else:
|
||||
for image_path in image_paths:
|
||||
experiment.log_image(image_path, name=image_path.stem, step=curr_step)
|
||||
|
||||
|
||||
def _log_image_predictions(experiment, validator, curr_step):
|
||||
"""Logs predicted boxes for a single image during training."""
|
||||
global _comet_image_prediction_count
|
||||
|
||||
task = validator.args.task
|
||||
if task not in COMET_SUPPORTED_TASKS:
|
||||
return
|
||||
|
||||
jdict = validator.jdict
|
||||
if not jdict:
|
||||
return
|
||||
|
||||
predictions_metadata_map = _create_prediction_metadata_map(jdict)
|
||||
dataloader = validator.dataloader
|
||||
class_label_map = validator.names
|
||||
|
||||
batch_logging_interval = _get_eval_batch_logging_interval()
|
||||
max_image_predictions = _get_max_image_predictions_to_log()
|
||||
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
if (batch_idx + 1) % batch_logging_interval != 0:
|
||||
continue
|
||||
|
||||
image_paths = batch['im_file']
|
||||
for img_idx, image_path in enumerate(image_paths):
|
||||
if _comet_image_prediction_count >= max_image_predictions:
|
||||
return
|
||||
|
||||
image_path = Path(image_path)
|
||||
annotations = _fetch_annotations(
|
||||
img_idx,
|
||||
image_path,
|
||||
batch,
|
||||
predictions_metadata_map,
|
||||
class_label_map,
|
||||
)
|
||||
_log_images(
|
||||
experiment,
|
||||
[image_path],
|
||||
curr_step,
|
||||
annotations=annotations,
|
||||
)
|
||||
_comet_image_prediction_count += 1
|
||||
|
||||
|
||||
def _log_plots(experiment, trainer):
|
||||
"""Logs evaluation plots and label plots for the experiment."""
|
||||
plot_filenames = [trainer.save_dir / f'{plots}.png' for plots in EVALUATION_PLOT_NAMES]
|
||||
_log_images(experiment, plot_filenames, None)
|
||||
|
||||
label_plot_filenames = [trainer.save_dir / f'{labels}.jpg' for labels in LABEL_PLOT_NAMES]
|
||||
_log_images(experiment, label_plot_filenames, None)
|
||||
|
||||
|
||||
def _log_model(experiment, trainer):
|
||||
"""Log the best-trained model to Comet.ml."""
|
||||
model_name = _get_comet_model_name()
|
||||
experiment.log_model(
|
||||
model_name,
|
||||
file_or_folder=str(trainer.best),
|
||||
file_name='best.pt',
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
"""Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
|
||||
experiment = comet_ml.get_global_experiment()
|
||||
is_alive = getattr(experiment, 'alive', False)
|
||||
if not experiment or not is_alive:
|
||||
_create_experiment(trainer.args)
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Log metrics and save batch images at the end of training epochs."""
|
||||
experiment = comet_ml.get_global_experiment()
|
||||
if not experiment:
|
||||
return
|
||||
|
||||
metadata = _fetch_trainer_metadata(trainer)
|
||||
curr_epoch = metadata['curr_epoch']
|
||||
curr_step = metadata['curr_step']
|
||||
|
||||
experiment.log_metrics(
|
||||
trainer.label_loss_items(trainer.tloss, prefix='train'),
|
||||
step=curr_step,
|
||||
epoch=curr_epoch,
|
||||
)
|
||||
|
||||
if curr_epoch == 1:
|
||||
_log_images(experiment, trainer.save_dir.glob('train_batch*.jpg'), curr_step)
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Logs model assets at the end of each epoch."""
|
||||
experiment = comet_ml.get_global_experiment()
|
||||
if not experiment:
|
||||
return
|
||||
|
||||
metadata = _fetch_trainer_metadata(trainer)
|
||||
curr_epoch = metadata['curr_epoch']
|
||||
curr_step = metadata['curr_step']
|
||||
save_assets = metadata['save_assets']
|
||||
|
||||
experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
|
||||
experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
|
||||
if curr_epoch == 1:
|
||||
experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
|
||||
|
||||
if not save_assets:
|
||||
return
|
||||
|
||||
_log_model(experiment, trainer)
|
||||
if _should_log_confusion_matrix():
|
||||
_log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
|
||||
if _should_log_image_predictions():
|
||||
_log_image_predictions(experiment, trainer.validator, curr_step)
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Perform operations at the end of training."""
|
||||
experiment = comet_ml.get_global_experiment()
|
||||
if not experiment:
|
||||
return
|
||||
|
||||
metadata = _fetch_trainer_metadata(trainer)
|
||||
curr_epoch = metadata['curr_epoch']
|
||||
curr_step = metadata['curr_step']
|
||||
plots = trainer.args.plots
|
||||
|
||||
_log_model(experiment, trainer)
|
||||
if plots:
|
||||
_log_plots(experiment, trainer)
|
||||
|
||||
_log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
|
||||
_log_image_predictions(experiment, trainer.validator, curr_step)
|
||||
experiment.end()
|
||||
|
||||
global _comet_image_prediction_count
|
||||
_comet_image_prediction_count = 0
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_train_epoch_end': on_train_epoch_end,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_train_end': on_train_end} if comet_ml else {}
|
@ -1,136 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
import os
|
||||
|
||||
import pkg_resources as pkg
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
|
||||
from ultralytics.yolo.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
import dvclive
|
||||
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
|
||||
ver = version('dvclive')
|
||||
if pkg.parse_version(ver) < pkg.parse_version('2.11.0'):
|
||||
LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).')
|
||||
dvclive = None # noqa: F811
|
||||
except (ImportError, AssertionError, TypeError):
|
||||
dvclive = None
|
||||
|
||||
# DVCLive logger instance
|
||||
live = None
|
||||
_processed_plots = {}
|
||||
|
||||
# `on_fit_epoch_end` is called on final validation (probably need to be fixed)
|
||||
# for now this is the way we distinguish final evaluation of the best model vs
|
||||
# last epoch validation
|
||||
_training_epoch = False
|
||||
|
||||
|
||||
def _logger_disabled():
|
||||
return os.getenv('ULTRALYTICS_DVC_DISABLED', 'false').lower() == 'true'
|
||||
|
||||
|
||||
def _log_images(image_path, prefix=''):
|
||||
if live:
|
||||
live.log_image(os.path.join(prefix, image_path.name), image_path)
|
||||
|
||||
|
||||
def _log_plots(plots, prefix=''):
|
||||
for name, params in plots.items():
|
||||
timestamp = params['timestamp']
|
||||
if _processed_plots.get(name) != timestamp:
|
||||
_log_images(name, prefix)
|
||||
_processed_plots[name] = timestamp
|
||||
|
||||
|
||||
def _log_confusion_matrix(validator):
|
||||
targets = []
|
||||
preds = []
|
||||
matrix = validator.confusion_matrix.matrix
|
||||
names = list(validator.names.values())
|
||||
if validator.confusion_matrix.task == 'detect':
|
||||
names += ['background']
|
||||
|
||||
for ti, pred in enumerate(matrix.T.astype(int)):
|
||||
for pi, num in enumerate(pred):
|
||||
targets.extend([names[ti]] * num)
|
||||
preds.extend([names[pi]] * num)
|
||||
|
||||
live.log_sklearn_plot('confusion_matrix', targets, preds, name='cf.json', normalized=True)
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
try:
|
||||
global live
|
||||
if not _logger_disabled():
|
||||
live = dvclive.Live(save_dvc_exp=True, cache_images=True)
|
||||
LOGGER.info(
|
||||
'DVCLive is detected and auto logging is enabled (can be disabled with `ULTRALYTICS_DVC_DISABLED=true`).'
|
||||
)
|
||||
else:
|
||||
LOGGER.debug('DVCLive is detected and auto logging is disabled via `ULTRALYTICS_DVC_DISABLED`.')
|
||||
live = None
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}')
|
||||
|
||||
|
||||
def on_pretrain_routine_end(trainer):
|
||||
_log_plots(trainer.plots, 'train')
|
||||
|
||||
|
||||
def on_train_start(trainer):
|
||||
if live:
|
||||
live.log_params(trainer.args)
|
||||
|
||||
|
||||
def on_train_epoch_start(trainer):
|
||||
global _training_epoch
|
||||
_training_epoch = True
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
global _training_epoch
|
||||
if live and _training_epoch:
|
||||
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
|
||||
for metric, value in all_metrics.items():
|
||||
live.log_metric(metric, value)
|
||||
|
||||
if trainer.epoch == 0:
|
||||
for metric, value in model_info_for_loggers(trainer).items():
|
||||
live.log_metric(metric, value, plot=False)
|
||||
|
||||
_log_plots(trainer.plots, 'train')
|
||||
_log_plots(trainer.validator.plots, 'val')
|
||||
|
||||
live.next_step()
|
||||
_training_epoch = False
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
if live:
|
||||
# At the end log the best metrics. It runs validator on the best model internally.
|
||||
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
|
||||
for metric, value in all_metrics.items():
|
||||
live.log_metric(metric, value, plot=False)
|
||||
|
||||
_log_plots(trainer.plots, 'eval')
|
||||
_log_plots(trainer.validator.plots, 'eval')
|
||||
_log_confusion_matrix(trainer.validator)
|
||||
|
||||
if trainer.best.exists():
|
||||
live.log_artifact(trainer.best, copy=True)
|
||||
|
||||
live.end()
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_pretrain_routine_end': on_pretrain_routine_end,
|
||||
'on_train_start': on_train_start,
|
||||
'on_train_epoch_start': on_train_epoch_start,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_train_end': on_train_end} if dvclive else {}
|
@ -1,87 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import json
|
||||
from time import time
|
||||
|
||||
from ultralytics.hub.utils import PREFIX, events
|
||||
from ultralytics.yolo.utils import LOGGER
|
||||
from ultralytics.yolo.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
|
||||
def on_pretrain_routine_end(trainer):
|
||||
"""Logs info before starting timer for upload rate limit."""
|
||||
session = getattr(trainer, 'hub_session', None)
|
||||
if session:
|
||||
# Start timer for upload rate limit
|
||||
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
|
||||
session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Uploads training progress metrics at the end of each epoch."""
|
||||
session = getattr(trainer, 'hub_session', None)
|
||||
if session:
|
||||
# Upload metrics after val end
|
||||
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
|
||||
if trainer.epoch == 0:
|
||||
all_plots = {**all_plots, **model_info_for_loggers(trainer)}
|
||||
session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
|
||||
if time() - session.timers['metrics'] > session.rate_limits['metrics']:
|
||||
session.upload_metrics()
|
||||
session.timers['metrics'] = time() # reset timer
|
||||
session.metrics_queue = {} # reset queue
|
||||
|
||||
|
||||
def on_model_save(trainer):
|
||||
"""Saves checkpoints to Ultralytics HUB with rate limiting."""
|
||||
session = getattr(trainer, 'hub_session', None)
|
||||
if session:
|
||||
# Upload checkpoints with rate limiting
|
||||
is_best = trainer.best_fitness == trainer.fitness
|
||||
if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
|
||||
LOGGER.info(f'{PREFIX}Uploading checkpoint https://hub.ultralytics.com/models/{session.model_id}')
|
||||
session.upload_model(trainer.epoch, trainer.last, is_best)
|
||||
session.timers['ckpt'] = time() # reset timer
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Upload final model and metrics to Ultralytics HUB at the end of training."""
|
||||
session = getattr(trainer, 'hub_session', None)
|
||||
if session:
|
||||
# Upload final model and metrics with exponential standoff
|
||||
LOGGER.info(f'{PREFIX}Syncing final model...')
|
||||
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
|
||||
session.alive = False # stop heartbeats
|
||||
LOGGER.info(f'{PREFIX}Done ✅\n'
|
||||
f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
|
||||
|
||||
|
||||
def on_train_start(trainer):
|
||||
"""Run events on train start."""
|
||||
events(trainer.args)
|
||||
|
||||
|
||||
def on_val_start(validator):
|
||||
"""Runs events on validation start."""
|
||||
events(validator.args)
|
||||
|
||||
|
||||
def on_predict_start(predictor):
|
||||
"""Run events on predict start."""
|
||||
events(predictor.args)
|
||||
|
||||
|
||||
def on_export_start(exporter):
|
||||
"""Run events on export start."""
|
||||
events(exporter.args)
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_end': on_pretrain_routine_end,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_model_save': on_model_save,
|
||||
'on_train_end': on_train_end,
|
||||
'on_train_start': on_train_start,
|
||||
'on_val_start': on_val_start,
|
||||
'on_predict_start': on_predict_start,
|
||||
'on_export_start': on_export_start}
|
@ -1,71 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING, colorstr
|
||||
|
||||
try:
|
||||
import mlflow
|
||||
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert hasattr(mlflow, '__version__') # verify package is not directory
|
||||
except (ImportError, AssertionError):
|
||||
mlflow = None
|
||||
|
||||
|
||||
def on_pretrain_routine_end(trainer):
|
||||
"""Logs training parameters to MLflow."""
|
||||
global mlflow, run, run_id, experiment_name
|
||||
|
||||
if os.environ.get('MLFLOW_TRACKING_URI') is None:
|
||||
mlflow = None
|
||||
|
||||
if mlflow:
|
||||
mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000"
|
||||
mlflow.set_tracking_uri(mlflow_location)
|
||||
|
||||
experiment_name = os.environ.get('MLFLOW_EXPERIMENT_NAME') or trainer.args.project or '/Shared/YOLOv8'
|
||||
run_name = os.environ.get('MLFLOW_RUN') or trainer.args.name
|
||||
experiment = mlflow.get_experiment_by_name(experiment_name)
|
||||
if experiment is None:
|
||||
mlflow.create_experiment(experiment_name)
|
||||
mlflow.set_experiment(experiment_name)
|
||||
|
||||
prefix = colorstr('MLFlow: ')
|
||||
try:
|
||||
run, active_run = mlflow, mlflow.active_run()
|
||||
if not active_run:
|
||||
active_run = mlflow.start_run(experiment_id=experiment.experiment_id, run_name=run_name)
|
||||
run_id = active_run.info.run_id
|
||||
LOGGER.info(f'{prefix}Using run_id({run_id}) at {mlflow_location}')
|
||||
run.log_params(vars(trainer.model.args))
|
||||
except Exception as err:
|
||||
LOGGER.error(f'{prefix}Failing init - {repr(err)}')
|
||||
LOGGER.warning(f'{prefix}Continuing without Mlflow')
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Logs training metrics to Mlflow."""
|
||||
if mlflow:
|
||||
metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()}
|
||||
run.log_metrics(metrics=metrics_dict, step=trainer.epoch)
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Called at end of train loop to log model artifact info."""
|
||||
if mlflow:
|
||||
root_dir = Path(__file__).resolve().parents[3]
|
||||
run.log_artifact(trainer.last)
|
||||
run.log_artifact(trainer.best)
|
||||
run.pyfunc.log_model(artifact_path=experiment_name,
|
||||
code_path=[str(root_dir)],
|
||||
artifacts={'model_path': str(trainer.save_dir)},
|
||||
python_model=run.pyfunc.PythonModel())
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_end': on_pretrain_routine_end,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_train_end': on_train_end} if mlflow else {}
|
@ -1,103 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import matplotlib.image as mpimg
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
|
||||
from ultralytics.yolo.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
try:
|
||||
import neptune
|
||||
from neptune.types import File
|
||||
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert hasattr(neptune, '__version__')
|
||||
except (ImportError, AssertionError):
|
||||
neptune = None
|
||||
|
||||
run = None # NeptuneAI experiment logger instance
|
||||
|
||||
|
||||
def _log_scalars(scalars, step=0):
|
||||
"""Log scalars to the NeptuneAI experiment logger."""
|
||||
if run:
|
||||
for k, v in scalars.items():
|
||||
run[k].append(value=v, step=step)
|
||||
|
||||
|
||||
def _log_images(imgs_dict, group=''):
|
||||
"""Log scalars to the NeptuneAI experiment logger."""
|
||||
if run:
|
||||
for k, v in imgs_dict.items():
|
||||
run[f'{group}/{k}'].upload(File(v))
|
||||
|
||||
|
||||
def _log_plot(title, plot_path):
|
||||
"""Log plots to the NeptuneAI experiment logger."""
|
||||
"""
|
||||
Log image as plot in the plot section of NeptuneAI
|
||||
|
||||
arguments:
|
||||
title (str) Title of the plot
|
||||
plot_path (PosixPath or str) Path to the saved image file
|
||||
"""
|
||||
img = mpimg.imread(plot_path)
|
||||
fig = plt.figure()
|
||||
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
|
||||
ax.imshow(img)
|
||||
run[f'Plots/{title}'].upload(fig)
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
"""Callback function called before the training routine starts."""
|
||||
try:
|
||||
global run
|
||||
run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8'])
|
||||
run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()}
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}')
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Callback function called at end of each training epoch."""
|
||||
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
|
||||
_log_scalars(trainer.lr, trainer.epoch + 1)
|
||||
if trainer.epoch == 1:
|
||||
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic')
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Callback function called at end of each fit (train+val) epoch."""
|
||||
if run and trainer.epoch == 0:
|
||||
run['Configuration/Model'] = model_info_for_loggers(trainer)
|
||||
_log_scalars(trainer.metrics, trainer.epoch + 1)
|
||||
|
||||
|
||||
def on_val_end(validator):
|
||||
"""Callback function called at end of each validation."""
|
||||
if run:
|
||||
# Log val_labels and val_pred
|
||||
_log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation')
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Callback function called at end of training."""
|
||||
if run:
|
||||
# Log final results, CM matrix + PR plots
|
||||
files = [
|
||||
'results.png', 'confusion_matrix.png', 'confusion_matrix_normalized.png',
|
||||
*(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
|
||||
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
|
||||
for f in files:
|
||||
_log_plot(title=f.stem, plot_path=f)
|
||||
# Log the final model
|
||||
run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str(
|
||||
trainer.best)))
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_train_epoch_end': on_train_epoch_end,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_val_end': on_val_end,
|
||||
'on_train_end': on_train_end} if neptune else {}
|
@ -1,20 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
try:
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.air import session
|
||||
except (ImportError, AssertionError):
|
||||
tune = None
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Sends training metrics to Ray Tune at end of each epoch."""
|
||||
if ray.tune.is_session_enabled():
|
||||
metrics = trainer.metrics
|
||||
metrics['epoch'] = trainer.epoch
|
||||
session.report(metrics)
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_fit_epoch_end': on_fit_epoch_end, } if tune else {}
|
@ -1,47 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING, colorstr
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
except (ImportError, AssertionError):
|
||||
SummaryWriter = None
|
||||
|
||||
writer = None # TensorBoard SummaryWriter instance
|
||||
|
||||
|
||||
def _log_scalars(scalars, step=0):
|
||||
"""Logs scalar values to TensorBoard."""
|
||||
if writer:
|
||||
for k, v in scalars.items():
|
||||
writer.add_scalar(k, v, step)
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
"""Initialize TensorBoard logging with SummaryWriter."""
|
||||
if SummaryWriter:
|
||||
try:
|
||||
global writer
|
||||
writer = SummaryWriter(str(trainer.save_dir))
|
||||
prefix = colorstr('TensorBoard: ')
|
||||
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
|
||||
|
||||
|
||||
def on_batch_end(trainer):
|
||||
"""Logs scalar statistics at the end of a training batch."""
|
||||
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Logs epoch metrics at end of training epoch."""
|
||||
_log_scalars(trainer.metrics, trainer.epoch + 1)
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_batch_end': on_batch_end}
|
@ -1,60 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
from ultralytics.yolo.utils import TESTS_RUNNING
|
||||
from ultralytics.yolo.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
try:
|
||||
import wandb as wb
|
||||
|
||||
assert hasattr(wb, '__version__')
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
except (ImportError, AssertionError):
|
||||
wb = None
|
||||
|
||||
_processed_plots = {}
|
||||
|
||||
|
||||
def _log_plots(plots, step):
|
||||
for name, params in plots.items():
|
||||
timestamp = params['timestamp']
|
||||
if _processed_plots.get(name, None) != timestamp:
|
||||
wb.run.log({name.stem: wb.Image(str(name))}, step=step)
|
||||
_processed_plots[name] = timestamp
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
"""Initiate and start project if module is present."""
|
||||
wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args))
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Logs training metrics and model information at the end of an epoch."""
|
||||
wb.run.log(trainer.metrics, step=trainer.epoch + 1)
|
||||
_log_plots(trainer.plots, step=trainer.epoch + 1)
|
||||
_log_plots(trainer.validator.plots, step=trainer.epoch + 1)
|
||||
if trainer.epoch == 0:
|
||||
wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Log metrics and save images at the end of each training epoch."""
|
||||
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
|
||||
wb.run.log(trainer.lr, step=trainer.epoch + 1)
|
||||
if trainer.epoch == 1:
|
||||
_log_plots(trainer.plots, step=trainer.epoch + 1)
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Save the best model as an artifact at end of training."""
|
||||
_log_plots(trainer.validator.plots, step=trainer.epoch + 1)
|
||||
_log_plots(trainer.plots, step=trainer.epoch + 1)
|
||||
art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
|
||||
if trainer.best.exists():
|
||||
art.add_file(trainer.best)
|
||||
wb.run.log_artifact(art, aliases=['best'])
|
||||
|
||||
|
||||
callbacks = {
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_train_epoch_end': on_train_epoch_end,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_train_end': on_train_end} if wb else {}
|
@ -1,459 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
import contextlib
|
||||
import glob
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pkg_resources as pkg
|
||||
import psutil
|
||||
import requests
|
||||
import torch
|
||||
from matplotlib import font_manager
|
||||
|
||||
from ultralytics.yolo.utils import (AUTOINSTALL, LOGGER, ONLINE, ROOT, USER_CONFIG_DIR, ThreadingLocked, TryExcept,
|
||||
clean_url, colorstr, downloads, emojis, is_colab, is_docker, is_jupyter, is_kaggle,
|
||||
is_online, is_pip_package, url2file)
|
||||
|
||||
|
||||
def is_ascii(s) -> bool:
|
||||
"""
|
||||
Check if a string is composed of only ASCII characters.
|
||||
|
||||
Args:
|
||||
s (str): String to be checked.
|
||||
|
||||
Returns:
|
||||
bool: True if the string is composed only of ASCII characters, False otherwise.
|
||||
"""
|
||||
# Convert list, tuple, None, etc. to string
|
||||
s = str(s)
|
||||
|
||||
# Check if the string is composed of only ASCII characters
|
||||
return all(ord(c) < 128 for c in s)
|
||||
|
||||
|
||||
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
|
||||
"""
|
||||
Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
|
||||
stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
|
||||
|
||||
Args:
|
||||
imgsz (int | cList[int]): Image size.
|
||||
stride (int): Stride value.
|
||||
min_dim (int): Minimum number of dimensions.
|
||||
floor (int): Minimum allowed value for image size.
|
||||
|
||||
Returns:
|
||||
(List[int]): Updated image size.
|
||||
"""
|
||||
# Convert stride to integer if it is a tensor
|
||||
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
|
||||
|
||||
# Convert image size to list if it is an integer
|
||||
if isinstance(imgsz, int):
|
||||
imgsz = [imgsz]
|
||||
elif isinstance(imgsz, (list, tuple)):
|
||||
imgsz = list(imgsz)
|
||||
else:
|
||||
raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
|
||||
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'")
|
||||
|
||||
# Apply max_dim
|
||||
if len(imgsz) > max_dim:
|
||||
msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \
|
||||
"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
|
||||
if max_dim != 1:
|
||||
raise ValueError(f'imgsz={imgsz} is not a valid image size. {msg}')
|
||||
LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
|
||||
imgsz = [max(imgsz)]
|
||||
# Make image size a multiple of the stride
|
||||
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
|
||||
|
||||
# Print warning message if image size was updated
|
||||
if sz != imgsz:
|
||||
LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}')
|
||||
|
||||
# Add missing dimensions if necessary
|
||||
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
|
||||
|
||||
return sz
|
||||
|
||||
|
||||
def check_version(current: str = '0.0.0',
|
||||
minimum: str = '0.0.0',
|
||||
name: str = 'version ',
|
||||
pinned: bool = False,
|
||||
hard: bool = False,
|
||||
verbose: bool = False) -> bool:
|
||||
"""
|
||||
Check current version against the required minimum version.
|
||||
|
||||
Args:
|
||||
current (str): Current version.
|
||||
minimum (str): Required minimum version.
|
||||
name (str): Name to be used in warning message.
|
||||
pinned (bool): If True, versions must match exactly. If False, minimum version must be satisfied.
|
||||
hard (bool): If True, raise an AssertionError if the minimum version is not met.
|
||||
verbose (bool): If True, print warning message if minimum version is not met.
|
||||
|
||||
Returns:
|
||||
(bool): True if minimum version is met, False otherwise.
|
||||
"""
|
||||
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
||||
result = (current == minimum) if pinned else (current >= minimum) # bool
|
||||
warning_message = f'WARNING ⚠️ {name}{minimum} is required by YOLOv8, but {name}{current} is currently installed'
|
||||
if hard:
|
||||
assert result, emojis(warning_message) # assert min requirements met
|
||||
if verbose and not result:
|
||||
LOGGER.warning(warning_message)
|
||||
return result
|
||||
|
||||
|
||||
def check_latest_pypi_version(package_name='ultralytics'):
|
||||
"""
|
||||
Returns the latest version of a PyPI package without downloading or installing it.
|
||||
|
||||
Parameters:
|
||||
package_name (str): The name of the package to find the latest version for.
|
||||
|
||||
Returns:
|
||||
(str): The latest version of the package.
|
||||
"""
|
||||
with contextlib.suppress(Exception):
|
||||
requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
|
||||
response = requests.get(f'https://pypi.org/pypi/{package_name}/json', timeout=3)
|
||||
if response.status_code == 200:
|
||||
return response.json()['info']['version']
|
||||
return None
|
||||
|
||||
|
||||
def check_pip_update_available():
|
||||
"""
|
||||
Checks if a new version of the ultralytics package is available on PyPI.
|
||||
|
||||
Returns:
|
||||
(bool): True if an update is available, False otherwise.
|
||||
"""
|
||||
if ONLINE and is_pip_package():
|
||||
with contextlib.suppress(Exception):
|
||||
from ultralytics import __version__
|
||||
latest = check_latest_pypi_version()
|
||||
if pkg.parse_version(__version__) < pkg.parse_version(latest): # update is available
|
||||
LOGGER.info(f'New https://pypi.org/project/ultralytics/{latest} available 😃 '
|
||||
f"Update with 'pip install -U ultralytics'")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@ThreadingLocked()
|
||||
def check_font(font='Arial.ttf'):
|
||||
"""
|
||||
Find font locally or download to user's configuration directory if it does not already exist.
|
||||
|
||||
Args:
|
||||
font (str): Path or name of font.
|
||||
|
||||
Returns:
|
||||
file (Path): Resolved font file path.
|
||||
"""
|
||||
name = Path(font).name
|
||||
|
||||
# Check USER_CONFIG_DIR
|
||||
file = USER_CONFIG_DIR / name
|
||||
if file.exists():
|
||||
return file
|
||||
|
||||
# Check system fonts
|
||||
matches = [s for s in font_manager.findSystemFonts() if font in s]
|
||||
if any(matches):
|
||||
return matches[0]
|
||||
|
||||
# Download to USER_CONFIG_DIR if missing
|
||||
url = f'https://ultralytics.com/assets/{name}'
|
||||
if downloads.is_url(url):
|
||||
downloads.safe_download(url=url, file=file)
|
||||
return file
|
||||
|
||||
|
||||
def check_python(minimum: str = '3.7.0') -> bool:
|
||||
"""
|
||||
Check current python version against the required minimum version.
|
||||
|
||||
Args:
|
||||
minimum (str): Required minimum version of python.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
return check_version(platform.python_version(), minimum, name='Python ', hard=True)
|
||||
|
||||
|
||||
@TryExcept()
|
||||
def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=(), install=True, cmds=''):
|
||||
"""
|
||||
Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.
|
||||
|
||||
Args:
|
||||
requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a
|
||||
string, or a list of package requirements as strings.
|
||||
exclude (Tuple[str]): Tuple of package names to exclude from checking.
|
||||
install (bool): If True, attempt to auto-update packages that don't meet requirements.
|
||||
cmds (str): Additional commands to pass to the pip install command when auto-updating.
|
||||
"""
|
||||
prefix = colorstr('red', 'bold', 'requirements:')
|
||||
check_python() # check python version
|
||||
check_torchvision() # check torch-torchvision compatibility
|
||||
if isinstance(requirements, Path): # requirements.txt file
|
||||
file = requirements.resolve()
|
||||
assert file.exists(), f'{prefix} {file} not found, check failed.'
|
||||
with file.open() as f:
|
||||
requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
|
||||
elif isinstance(requirements, str):
|
||||
requirements = [requirements]
|
||||
|
||||
s = '' # console string
|
||||
pkgs = []
|
||||
for r in requirements:
|
||||
r_stripped = r.split('/')[-1].replace('.git', '') # replace git+https://org/repo.git -> 'repo'
|
||||
try:
|
||||
pkg.require(r_stripped)
|
||||
except (pkg.VersionConflict, pkg.DistributionNotFound): # exception if requirements not met
|
||||
try: # attempt to import (slower but more accurate)
|
||||
import importlib
|
||||
importlib.import_module(next(pkg.parse_requirements(r_stripped)).name)
|
||||
except ImportError:
|
||||
s += f'"{r}" '
|
||||
pkgs.append(r)
|
||||
|
||||
if s:
|
||||
if install and AUTOINSTALL: # check environment variable
|
||||
n = len(pkgs) # number of packages updates
|
||||
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
|
||||
try:
|
||||
t = time.time()
|
||||
assert is_online(), 'AutoUpdate skipped (offline)'
|
||||
LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())
|
||||
dt = time.time() - t
|
||||
LOGGER.info(
|
||||
f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
|
||||
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n")
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'{prefix} ❌ {e}')
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def check_torchvision():
|
||||
"""
|
||||
Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
|
||||
|
||||
This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
|
||||
to the provided compatibility table based on https://github.com/pytorch/vision#installation. The
|
||||
compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
|
||||
Torchvision versions.
|
||||
"""
|
||||
|
||||
import torchvision
|
||||
|
||||
# Compatibility table
|
||||
compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}
|
||||
|
||||
# Extract only the major and minor versions
|
||||
v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
|
||||
v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2])
|
||||
|
||||
if v_torch in compatibility_table:
|
||||
compatible_versions = compatibility_table[v_torch]
|
||||
if all(pkg.parse_version(v_torchvision) != pkg.parse_version(v) for v in compatible_versions):
|
||||
print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
|
||||
f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
|
||||
"'pip install -U torch torchvision' to update both.\n"
|
||||
'For a full compatibility table see https://github.com/pytorch/vision#installation')
|
||||
|
||||
|
||||
def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
|
||||
"""Check file(s) for acceptable suffix."""
|
||||
if file and suffix:
|
||||
if isinstance(suffix, str):
|
||||
suffix = (suffix, )
|
||||
for f in file if isinstance(file, (list, tuple)) else [file]:
|
||||
s = Path(f).suffix.lower().strip() # file suffix
|
||||
if len(s):
|
||||
assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}, not {s}'
|
||||
|
||||
|
||||
def check_yolov5u_filename(file: str, verbose: bool = True):
|
||||
"""Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
|
||||
if ('yolov3' in file or 'yolov5' in file) and 'u' not in file:
|
||||
original_file = file
|
||||
file = re.sub(r'(.*yolov5([nsmlx]))\.pt', '\\1u.pt', file) # i.e. yolov5n.pt -> yolov5nu.pt
|
||||
file = re.sub(r'(.*yolov5([nsmlx])6)\.pt', '\\1u.pt', file) # i.e. yolov5n6.pt -> yolov5n6u.pt
|
||||
file = re.sub(r'(.*yolov3(|-tiny|-spp))\.pt', '\\1u.pt', file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
|
||||
if file != original_file and verbose:
|
||||
LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
|
||||
f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs '
|
||||
f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n')
|
||||
return file
|
||||
|
||||
|
||||
def check_file(file, suffix='', download=True, hard=True):
|
||||
"""Search/download file (if necessary) and return path."""
|
||||
check_suffix(file, suffix) # optional
|
||||
file = str(file).strip() # convert to string and strip spaces
|
||||
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
|
||||
if not file or ('://' not in file and Path(file).exists()): # exists ('://' check required in Windows Python<3.10)
|
||||
return file
|
||||
elif download and file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')): # download
|
||||
url = file # warning: Pathlib turns :// -> :/
|
||||
file = url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth
|
||||
if Path(file).exists():
|
||||
LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists
|
||||
else:
|
||||
downloads.safe_download(url=url, file=file, unzip=False)
|
||||
return file
|
||||
else: # search
|
||||
files = []
|
||||
for d in 'models', 'datasets', 'tracker/cfg', 'yolo/cfg': # search directories
|
||||
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
|
||||
if not files and hard:
|
||||
raise FileNotFoundError(f"'{file}' does not exist")
|
||||
elif len(files) > 1 and hard:
|
||||
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
|
||||
return files[0] if len(files) else [] # return file
|
||||
|
||||
|
||||
def check_yaml(file, suffix=('.yaml', '.yml'), hard=True):
|
||||
"""Search/download YAML file (if necessary) and return path, checking suffix."""
|
||||
return check_file(file, suffix, hard=hard)
|
||||
|
||||
|
||||
def check_imshow(warn=False):
|
||||
"""Check if environment supports image displays."""
|
||||
try:
|
||||
assert not any((is_colab(), is_kaggle(), is_docker()))
|
||||
cv2.imshow('test', np.zeros((1, 1, 3)))
|
||||
cv2.waitKey(1)
|
||||
cv2.destroyAllWindows()
|
||||
cv2.waitKey(1)
|
||||
return True
|
||||
except Exception as e:
|
||||
if warn:
|
||||
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
|
||||
return False
|
||||
|
||||
|
||||
def check_yolo(verbose=True, device=''):
|
||||
"""Return a human-readable YOLO software and hardware summary."""
|
||||
from ultralytics.yolo.utils.torch_utils import select_device
|
||||
|
||||
if is_jupyter():
|
||||
if check_requirements('wandb', install=False):
|
||||
os.system('pip uninstall -y wandb') # uninstall wandb: unwanted account creation prompt with infinite hang
|
||||
if is_colab():
|
||||
shutil.rmtree('sample_data', ignore_errors=True) # remove colab /sample_data directory
|
||||
|
||||
if verbose:
|
||||
# System info
|
||||
gib = 1 << 30 # bytes per GiB
|
||||
ram = psutil.virtual_memory().total
|
||||
total, used, free = shutil.disk_usage('/')
|
||||
s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
|
||||
with contextlib.suppress(Exception): # clear display if ipython is installed
|
||||
from IPython import display
|
||||
display.clear_output()
|
||||
else:
|
||||
s = ''
|
||||
|
||||
select_device(device=device, newline=False)
|
||||
LOGGER.info(f'Setup complete ✅ {s}')
|
||||
|
||||
|
||||
def check_amp(model):
|
||||
"""
|
||||
This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model.
|
||||
If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
|
||||
results, so AMP will be disabled during training.
|
||||
|
||||
Args:
|
||||
model (nn.Module): A YOLOv8 model instance.
|
||||
|
||||
Returns:
|
||||
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the AMP checks fail, indicating anomalies with the AMP functionality on the system.
|
||||
"""
|
||||
device = next(model.parameters()).device # get model device
|
||||
if device.type in ('cpu', 'mps'):
|
||||
return False # AMP only used on CUDA devices
|
||||
|
||||
def amp_allclose(m, im):
|
||||
"""All close FP32 vs AMP results."""
|
||||
a = m(im, device=device, verbose=False)[0].boxes.data # FP32 inference
|
||||
with torch.cuda.amp.autocast(True):
|
||||
b = m(im, device=device, verbose=False)[0].boxes.data # AMP inference
|
||||
del m
|
||||
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance
|
||||
|
||||
f = ROOT / 'assets/bus.jpg' # image to check
|
||||
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
|
||||
prefix = colorstr('AMP: ')
|
||||
LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
|
||||
warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
assert amp_allclose(YOLO('yolov8n.pt'), im)
|
||||
LOGGER.info(f'{prefix}checks passed ✅')
|
||||
except ConnectionError:
|
||||
LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
|
||||
except (AttributeError, ModuleNotFoundError):
|
||||
LOGGER.warning(
|
||||
f'{prefix}checks skipped ⚠️. Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}'
|
||||
)
|
||||
except AssertionError:
|
||||
LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
|
||||
f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def git_describe(path=ROOT): # path must be a directory
|
||||
"""Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
|
||||
try:
|
||||
assert (Path(path) / '.git').is_dir()
|
||||
return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
|
||||
except AssertionError:
|
||||
return ''
|
||||
|
||||
|
||||
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
|
||||
"""Print function arguments (optional args dict)."""
|
||||
|
||||
def strip_auth(v):
|
||||
"""Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
|
||||
return clean_url(v) if (isinstance(v, str) and v.startswith('http') and len(v) > 100) else v
|
||||
|
||||
x = inspect.currentframe().f_back # previous frame
|
||||
file, _, func, _, _ = inspect.getframeinfo(x)
|
||||
if args is None: # get args automatically
|
||||
args, _, _, frm = inspect.getargvalues(x)
|
||||
args = {k: v for k, v in frm.items() if k in args}
|
||||
try:
|
||||
file = Path(file).resolve().relative_to(ROOT).with_suffix('')
|
||||
except ValueError:
|
||||
file = Path(file).stem
|
||||
s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
|
||||
LOGGER.info(colorstr(s) + ', '.join(f'{k}={strip_auth(v)}' for k, v in args.items()))
|
@ -1,67 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import socket
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from . import USER_CONFIG_DIR
|
||||
from .torch_utils import TORCH_1_9
|
||||
|
||||
|
||||
def find_free_network_port() -> int:
|
||||
"""Finds a free port on localhost.
|
||||
|
||||
It is useful in single-node training when we don't want to connect to a real main node but have to set the
|
||||
`MASTER_PORT` environment variable.
|
||||
"""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('127.0.0.1', 0))
|
||||
return s.getsockname()[1] # port
|
||||
|
||||
|
||||
def generate_ddp_file(trainer):
|
||||
"""Generates a DDP file and returns its file name."""
|
||||
module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
|
||||
|
||||
content = f'''overrides = {vars(trainer.args)} \nif __name__ == "__main__":
|
||||
from {module} import {name}
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT
|
||||
|
||||
cfg = DEFAULT_CFG_DICT.copy()
|
||||
cfg.update(save_dir='') # handle the extra key 'save_dir'
|
||||
trainer = {name}(cfg=cfg, overrides=overrides)
|
||||
trainer.train()'''
|
||||
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(prefix='_temp_',
|
||||
suffix=f'{id(trainer)}.py',
|
||||
mode='w+',
|
||||
encoding='utf-8',
|
||||
dir=USER_CONFIG_DIR / 'DDP',
|
||||
delete=False) as file:
|
||||
file.write(content)
|
||||
return file.name
|
||||
|
||||
|
||||
def generate_ddp_command(world_size, trainer):
|
||||
"""Generates and returns command for distributed training."""
|
||||
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
||||
if not trainer.resume:
|
||||
shutil.rmtree(trainer.save_dir) # remove the save_dir
|
||||
file = str(Path(sys.argv[0]).resolve())
|
||||
safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$') # allowed characters and maximum of 100 characters
|
||||
if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')): # using CLI
|
||||
file = generate_ddp_file(trainer)
|
||||
dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
|
||||
port = find_free_network_port()
|
||||
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file]
|
||||
return cmd, file
|
||||
|
||||
|
||||
def ddp_cleanup(trainer, file):
|
||||
"""Delete temp file if created."""
|
||||
if f'{id(trainer)}.py' in file: # if temp_file suffix in file
|
||||
os.remove(file)
|
@ -1,271 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import shutil
|
||||
import subprocess
|
||||
from itertools import repeat
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
from urllib import parse, request
|
||||
from zipfile import BadZipFile, ZipFile, is_zipfile
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, checks, clean_url, emojis, is_online, url2file
|
||||
|
||||
GITHUB_ASSET_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in ('', '6', '-cls', '-seg', '-pose')] + \
|
||||
[f'yolov5{k}u.pt' for k in 'nsmlx'] + \
|
||||
[f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')] + \
|
||||
[f'yolo_nas_{k}.pt' for k in 'sml'] + \
|
||||
[f'sam_{k}.pt' for k in 'bl'] + \
|
||||
[f'FastSAM-{k}.pt' for k in 'sx'] + \
|
||||
[f'rtdetr-{k}.pt' for k in 'lx'] + \
|
||||
['mobile_sam.pt']
|
||||
GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES]
|
||||
|
||||
|
||||
def is_url(url, check=True):
|
||||
"""Check if string is URL and check if URL exists."""
|
||||
with contextlib.suppress(Exception):
|
||||
url = str(url)
|
||||
result = parse.urlparse(url)
|
||||
assert all([result.scheme, result.netloc]) # check if is url
|
||||
if check:
|
||||
with request.urlopen(url) as response:
|
||||
return response.getcode() == 200 # check if exists online
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX'), exist_ok=False):
|
||||
"""
|
||||
Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.
|
||||
|
||||
If the zipfile does not contain a single top-level directory, the function will create a new
|
||||
directory with the same name as the zipfile (without the extension) to extract its contents.
|
||||
If a path is not provided, the function will use the parent directory of the zipfile as the default path.
|
||||
|
||||
Args:
|
||||
file (str): The path to the zipfile to be extracted.
|
||||
path (str, optional): The path to extract the zipfile to. Defaults to None.
|
||||
exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
|
||||
exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False.
|
||||
|
||||
Raises:
|
||||
BadZipFile: If the provided file does not exist or is not a valid zipfile.
|
||||
|
||||
Returns:
|
||||
(Path): The path to the directory where the zipfile was extracted.
|
||||
"""
|
||||
if not (Path(file).exists() and is_zipfile(file)):
|
||||
raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.")
|
||||
if path is None:
|
||||
path = Path(file).parent # default path
|
||||
|
||||
# Unzip the file contents
|
||||
with ZipFile(file) as zipObj:
|
||||
file_list = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
|
||||
top_level_dirs = {Path(f).parts[0] for f in file_list}
|
||||
|
||||
if len(top_level_dirs) > 1 or not file_list[0].endswith('/'):
|
||||
path = Path(path) / Path(file).stem # define new unzip directory
|
||||
|
||||
# Check if destination directory already exists and contains files
|
||||
extract_path = Path(path) / list(top_level_dirs)[0]
|
||||
if extract_path.exists() and any(extract_path.iterdir()) and not exist_ok:
|
||||
# If it exists and is not empty, return the path without unzipping
|
||||
LOGGER.info(f'Skipping {file} unzip (already unzipped)')
|
||||
return path
|
||||
|
||||
for f in file_list:
|
||||
zipObj.extract(f, path=path)
|
||||
|
||||
return path # return unzip dir
|
||||
|
||||
|
||||
def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, hard=True):
|
||||
"""
|
||||
Check if there is sufficient disk space to download and store a file.
|
||||
|
||||
Args:
|
||||
url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco128.zip'.
|
||||
sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 2.0.
|
||||
hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True.
|
||||
|
||||
Returns:
|
||||
(bool): True if there is sufficient disk space, False otherwise.
|
||||
"""
|
||||
with contextlib.suppress(Exception):
|
||||
gib = 1 << 30 # bytes per GiB
|
||||
data = int(requests.head(url).headers['Content-Length']) / gib # file size (GB)
|
||||
total, used, free = (x / gib for x in shutil.disk_usage('/')) # bytes
|
||||
if data * sf < free:
|
||||
return True # sufficient space
|
||||
|
||||
# Insufficient space
|
||||
text = (f'WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, '
|
||||
f'Please free {data * sf - free:.1f} GB additional disk space and try again.')
|
||||
if hard:
|
||||
raise MemoryError(text)
|
||||
else:
|
||||
LOGGER.warning(text)
|
||||
return False
|
||||
|
||||
# Pass if error
|
||||
return True
|
||||
|
||||
|
||||
def safe_download(url,
|
||||
file=None,
|
||||
dir=None,
|
||||
unzip=True,
|
||||
delete=False,
|
||||
curl=False,
|
||||
retry=3,
|
||||
min_bytes=1E0,
|
||||
progress=True):
|
||||
"""
|
||||
Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
|
||||
|
||||
Args:
|
||||
url (str): The URL of the file to be downloaded.
|
||||
file (str, optional): The filename of the downloaded file.
|
||||
If not provided, the file will be saved with the same name as the URL.
|
||||
dir (str, optional): The directory to save the downloaded file.
|
||||
If not provided, the file will be saved in the current working directory.
|
||||
unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
|
||||
delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
|
||||
curl (bool, optional): Whether to use curl command line tool for downloading. Default: False.
|
||||
retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
|
||||
min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
|
||||
a successful download. Default: 1E0.
|
||||
progress (bool, optional): Whether to display a progress bar during the download. Default: True.
|
||||
"""
|
||||
f = dir / url2file(url) if dir else Path(file) # URL converted to filename
|
||||
if '://' not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
|
||||
f = Path(url) # filename
|
||||
elif not f.is_file(): # URL and file do not exist
|
||||
assert dir or file, 'dir or file required for download'
|
||||
f = dir / url2file(url) if dir else Path(file)
|
||||
desc = f'Downloading {clean_url(url)} to {f}'
|
||||
LOGGER.info(f'{desc}...')
|
||||
f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
|
||||
check_disk_space(url)
|
||||
for i in range(retry + 1):
|
||||
try:
|
||||
if curl or i > 0: # curl download with retry, continue
|
||||
s = 'sS' * (not progress) # silent
|
||||
r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '3', '-C', '-']).returncode
|
||||
assert r == 0, f'Curl return value {r}'
|
||||
else: # urllib download
|
||||
method = 'torch'
|
||||
if method == 'torch':
|
||||
torch.hub.download_url_to_file(url, f, progress=progress)
|
||||
else:
|
||||
from ultralytics.yolo.utils import TQDM_BAR_FORMAT
|
||||
with request.urlopen(url) as response, tqdm(total=int(response.getheader('Content-Length', 0)),
|
||||
desc=desc,
|
||||
disable=not progress,
|
||||
unit='B',
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
bar_format=TQDM_BAR_FORMAT) as pbar:
|
||||
with open(f, 'wb') as f_opened:
|
||||
for data in response:
|
||||
f_opened.write(data)
|
||||
pbar.update(len(data))
|
||||
|
||||
if f.exists():
|
||||
if f.stat().st_size > min_bytes:
|
||||
break # success
|
||||
f.unlink() # remove partial downloads
|
||||
except Exception as e:
|
||||
if i == 0 and not is_online():
|
||||
raise ConnectionError(emojis(f'❌ Download failure for {url}. Environment is not online.')) from e
|
||||
elif i >= retry:
|
||||
raise ConnectionError(emojis(f'❌ Download failure for {url}. Retry limit reached.')) from e
|
||||
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
|
||||
|
||||
if unzip and f.exists() and f.suffix in ('', '.zip', '.tar', '.gz'):
|
||||
unzip_dir = dir or f.parent # unzip to dir if provided else unzip in place
|
||||
LOGGER.info(f'Unzipping {f} to {unzip_dir.absolute()}...')
|
||||
if is_zipfile(f):
|
||||
unzip_dir = unzip_file(file=f, path=unzip_dir) # unzip
|
||||
elif f.suffix == '.tar':
|
||||
subprocess.run(['tar', 'xf', f, '--directory', unzip_dir], check=True) # unzip
|
||||
elif f.suffix == '.gz':
|
||||
subprocess.run(['tar', 'xfz', f, '--directory', unzip_dir], check=True) # unzip
|
||||
if delete:
|
||||
f.unlink() # remove zip
|
||||
return unzip_dir
|
||||
|
||||
|
||||
def get_github_assets(repo='ultralytics/assets', version='latest'):
|
||||
"""Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])."""
|
||||
if version != 'latest':
|
||||
version = f'tags/{version}' # i.e. tags/v6.2
|
||||
response = requests.get(f'https://api.github.com/repos/{repo}/releases/{version}').json() # github api
|
||||
return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
|
||||
|
||||
|
||||
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
"""Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc."""
|
||||
from ultralytics.yolo.utils import SETTINGS # scoped for circular import
|
||||
|
||||
# YOLOv3/5u updates
|
||||
file = str(file)
|
||||
file = checks.check_yolov5u_filename(file)
|
||||
file = Path(file.strip().replace("'", ''))
|
||||
if file.exists():
|
||||
return str(file)
|
||||
elif (SETTINGS['weights_dir'] / file).exists():
|
||||
return str(SETTINGS['weights_dir'] / file)
|
||||
else:
|
||||
# URL specified
|
||||
name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc.
|
||||
if str(file).startswith(('http:/', 'https:/')): # download
|
||||
url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
|
||||
file = url2file(name) # parse authentication https://url.com/file.txt?auth...
|
||||
if Path(file).is_file():
|
||||
LOGGER.info(f'Found {clean_url(url)} locally at {file}') # file already exists
|
||||
else:
|
||||
safe_download(url=url, file=file, min_bytes=1E5)
|
||||
return file
|
||||
|
||||
# GitHub assets
|
||||
assets = GITHUB_ASSET_NAMES
|
||||
try:
|
||||
tag, assets = get_github_assets(repo, release)
|
||||
except Exception:
|
||||
try:
|
||||
tag, assets = get_github_assets(repo) # latest release
|
||||
except Exception:
|
||||
try:
|
||||
tag = subprocess.check_output(['git', 'tag']).decode().split()[-1]
|
||||
except Exception:
|
||||
tag = release
|
||||
|
||||
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
|
||||
if name in assets:
|
||||
safe_download(url=f'https://github.com/{repo}/releases/download/{tag}/{name}', file=file, min_bytes=1E5)
|
||||
|
||||
return str(file)
|
||||
|
||||
|
||||
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3):
|
||||
"""Downloads and unzips files concurrently if threads > 1, else sequentially."""
|
||||
dir = Path(dir)
|
||||
dir.mkdir(parents=True, exist_ok=True) # make directory
|
||||
if threads > 1:
|
||||
with ThreadPool(threads) as pool:
|
||||
pool.map(
|
||||
lambda x: safe_download(
|
||||
url=x[0], dir=x[1], unzip=unzip, delete=delete, curl=curl, retry=retry, progress=threads <= 1),
|
||||
zip(url, repeat(dir)))
|
||||
pool.close()
|
||||
pool.join()
|
||||
else:
|
||||
for u in [url] if isinstance(url, (str, Path)) else url:
|
||||
safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry)
|
@ -1,10 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from ultralytics.yolo.utils import emojis
|
||||
|
||||
|
||||
class HUBModelError(Exception):
|
||||
|
||||
def __init__(self, message='Model not found. Please check model URL and try again.'):
|
||||
"""Create an exception for when a model is not found."""
|
||||
super().__init__(emojis(message))
|
@ -1,100 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class WorkingDirectory(contextlib.ContextDecorator):
|
||||
"""Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager."""
|
||||
|
||||
def __init__(self, new_dir):
|
||||
"""Sets the working directory to 'new_dir' upon instantiation."""
|
||||
self.dir = new_dir # new dir
|
||||
self.cwd = Path.cwd().resolve() # current dir
|
||||
|
||||
def __enter__(self):
|
||||
"""Changes the current directory to the specified directory."""
|
||||
os.chdir(self.dir)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Restore the current working directory on context exit."""
|
||||
os.chdir(self.cwd)
|
||||
|
||||
|
||||
def increment_path(path, exist_ok=False, sep='', mkdir=False):
|
||||
"""
|
||||
Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
||||
|
||||
If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
|
||||
the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
|
||||
number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
|
||||
directory if it does not already exist.
|
||||
|
||||
Args:
|
||||
path (str, pathlib.Path): Path to increment.
|
||||
exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False.
|
||||
sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''.
|
||||
mkdir (bool, optional): Create a directory if it does not exist. Defaults to False.
|
||||
|
||||
Returns:
|
||||
(pathlib.Path): Incremented path.
|
||||
"""
|
||||
path = Path(path) # os-agnostic
|
||||
if path.exists() and not exist_ok:
|
||||
path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
|
||||
|
||||
# Method 1
|
||||
for n in range(2, 9999):
|
||||
p = f'{path}{sep}{n}{suffix}' # increment path
|
||||
if not os.path.exists(p): #
|
||||
break
|
||||
path = Path(p)
|
||||
|
||||
if mkdir:
|
||||
path.mkdir(parents=True, exist_ok=True) # make directory
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def file_age(path=__file__):
|
||||
"""Return days since last file update."""
|
||||
dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
|
||||
return dt.days # + dt.seconds / 86400 # fractional days
|
||||
|
||||
|
||||
def file_date(path=__file__):
|
||||
"""Return human-readable file modification date, i.e. '2021-3-26'."""
|
||||
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
|
||||
return f'{t.year}-{t.month}-{t.day}'
|
||||
|
||||
|
||||
def file_size(path):
|
||||
"""Return file/dir size (MB)."""
|
||||
if isinstance(path, (str, Path)):
|
||||
mb = 1 << 20 # bytes to MiB (1024 ** 2)
|
||||
path = Path(path)
|
||||
if path.is_file():
|
||||
return path.stat().st_size / mb
|
||||
elif path.is_dir():
|
||||
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
|
||||
return 0.0
|
||||
|
||||
|
||||
def get_latest_run(search_dir='.'):
|
||||
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
|
||||
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
|
||||
return max(last_list, key=os.path.getctime) if last_list else ''
|
||||
|
||||
|
||||
def make_dirs(dir='new_dir/'):
|
||||
"""Create directories."""
|
||||
dir = Path(dir)
|
||||
if dir.exists():
|
||||
shutil.rmtree(dir) # delete dir
|
||||
for p in dir, dir / 'labels', dir / 'images':
|
||||
p.mkdir(parents=True, exist_ok=True) # make dir
|
||||
return dir
|
@ -1,392 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from collections import abc
|
||||
from itertools import repeat
|
||||
from numbers import Number
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
"""From PyTorch internals."""
|
||||
|
||||
def parse(x):
|
||||
"""Parse bounding boxes format between XYWH and LTWH."""
|
||||
return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
to_2tuple = _ntuple(2)
|
||||
to_4tuple = _ntuple(4)
|
||||
|
||||
# `xyxy` means left top and right bottom
|
||||
# `xywh` means center x, center y and width, height(yolo format)
|
||||
# `ltwh` means left top and width, height(coco format)
|
||||
_formats = ['xyxy', 'xywh', 'ltwh']
|
||||
|
||||
__all__ = 'Bboxes', # tuple or list
|
||||
|
||||
|
||||
class Bboxes:
|
||||
"""Now only numpy is supported."""
|
||||
|
||||
def __init__(self, bboxes, format='xyxy') -> None:
|
||||
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
|
||||
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
|
||||
assert bboxes.ndim == 2
|
||||
assert bboxes.shape[1] == 4
|
||||
self.bboxes = bboxes
|
||||
self.format = format
|
||||
# self.normalized = normalized
|
||||
|
||||
# def convert(self, format):
|
||||
# assert format in _formats
|
||||
# if self.format == format:
|
||||
# bboxes = self.bboxes
|
||||
# elif self.format == "xyxy":
|
||||
# if format == "xywh":
|
||||
# bboxes = xyxy2xywh(self.bboxes)
|
||||
# else:
|
||||
# bboxes = xyxy2ltwh(self.bboxes)
|
||||
# elif self.format == "xywh":
|
||||
# if format == "xyxy":
|
||||
# bboxes = xywh2xyxy(self.bboxes)
|
||||
# else:
|
||||
# bboxes = xywh2ltwh(self.bboxes)
|
||||
# else:
|
||||
# if format == "xyxy":
|
||||
# bboxes = ltwh2xyxy(self.bboxes)
|
||||
# else:
|
||||
# bboxes = ltwh2xywh(self.bboxes)
|
||||
#
|
||||
# return Bboxes(bboxes, format)
|
||||
|
||||
def convert(self, format):
|
||||
"""Converts bounding box format from one type to another."""
|
||||
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
|
||||
if self.format == format:
|
||||
return
|
||||
elif self.format == 'xyxy':
|
||||
bboxes = xyxy2xywh(self.bboxes) if format == 'xywh' else xyxy2ltwh(self.bboxes)
|
||||
elif self.format == 'xywh':
|
||||
bboxes = xywh2xyxy(self.bboxes) if format == 'xyxy' else xywh2ltwh(self.bboxes)
|
||||
else:
|
||||
bboxes = ltwh2xyxy(self.bboxes) if format == 'xyxy' else ltwh2xywh(self.bboxes)
|
||||
self.bboxes = bboxes
|
||||
self.format = format
|
||||
|
||||
def areas(self):
|
||||
"""Return box areas."""
|
||||
self.convert('xyxy')
|
||||
return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
|
||||
|
||||
# def denormalize(self, w, h):
|
||||
# if not self.normalized:
|
||||
# return
|
||||
# assert (self.bboxes <= 1.0).all()
|
||||
# self.bboxes[:, 0::2] *= w
|
||||
# self.bboxes[:, 1::2] *= h
|
||||
# self.normalized = False
|
||||
#
|
||||
# def normalize(self, w, h):
|
||||
# if self.normalized:
|
||||
# return
|
||||
# assert (self.bboxes > 1.0).any()
|
||||
# self.bboxes[:, 0::2] /= w
|
||||
# self.bboxes[:, 1::2] /= h
|
||||
# self.normalized = True
|
||||
|
||||
def mul(self, scale):
|
||||
"""
|
||||
Args:
|
||||
scale (tuple | list | int): the scale for four coords.
|
||||
"""
|
||||
if isinstance(scale, Number):
|
||||
scale = to_4tuple(scale)
|
||||
assert isinstance(scale, (tuple, list))
|
||||
assert len(scale) == 4
|
||||
self.bboxes[:, 0] *= scale[0]
|
||||
self.bboxes[:, 1] *= scale[1]
|
||||
self.bboxes[:, 2] *= scale[2]
|
||||
self.bboxes[:, 3] *= scale[3]
|
||||
|
||||
def add(self, offset):
|
||||
"""
|
||||
Args:
|
||||
offset (tuple | list | int): the offset for four coords.
|
||||
"""
|
||||
if isinstance(offset, Number):
|
||||
offset = to_4tuple(offset)
|
||||
assert isinstance(offset, (tuple, list))
|
||||
assert len(offset) == 4
|
||||
self.bboxes[:, 0] += offset[0]
|
||||
self.bboxes[:, 1] += offset[1]
|
||||
self.bboxes[:, 2] += offset[2]
|
||||
self.bboxes[:, 3] += offset[3]
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of boxes."""
|
||||
return len(self.bboxes)
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
|
||||
"""
|
||||
Concatenate a list of Bboxes objects into a single Bboxes object.
|
||||
|
||||
Args:
|
||||
boxes_list (List[Bboxes]): A list of Bboxes objects to concatenate.
|
||||
axis (int, optional): The axis along which to concatenate the bounding boxes.
|
||||
Defaults to 0.
|
||||
|
||||
Returns:
|
||||
Bboxes: A new Bboxes object containing the concatenated bounding boxes.
|
||||
|
||||
Note:
|
||||
The input should be a list or tuple of Bboxes objects.
|
||||
"""
|
||||
assert isinstance(boxes_list, (list, tuple))
|
||||
if not boxes_list:
|
||||
return cls(np.empty(0))
|
||||
assert all(isinstance(box, Bboxes) for box in boxes_list)
|
||||
|
||||
if len(boxes_list) == 1:
|
||||
return boxes_list[0]
|
||||
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
|
||||
|
||||
def __getitem__(self, index) -> 'Bboxes':
|
||||
"""
|
||||
Retrieve a specific bounding box or a set of bounding boxes using indexing.
|
||||
|
||||
Args:
|
||||
index (int, slice, or np.ndarray): The index, slice, or boolean array to select
|
||||
the desired bounding boxes.
|
||||
|
||||
Returns:
|
||||
Bboxes: A new Bboxes object containing the selected bounding boxes.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the indexed bounding boxes do not form a 2-dimensional matrix.
|
||||
|
||||
Note:
|
||||
When using boolean indexing, make sure to provide a boolean array with the same
|
||||
length as the number of bounding boxes.
|
||||
"""
|
||||
if isinstance(index, int):
|
||||
return Bboxes(self.bboxes[index].view(1, -1))
|
||||
b = self.bboxes[index]
|
||||
assert b.ndim == 2, f'Indexing on Bboxes with {index} failed to return a matrix!'
|
||||
return Bboxes(b)
|
||||
|
||||
|
||||
class Instances:
|
||||
|
||||
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None:
|
||||
"""
|
||||
Args:
|
||||
bboxes (ndarray): bboxes with shape [N, 4].
|
||||
segments (list | ndarray): segments.
|
||||
keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3].
|
||||
"""
|
||||
if segments is None:
|
||||
segments = []
|
||||
self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
|
||||
self.keypoints = keypoints
|
||||
self.normalized = normalized
|
||||
|
||||
if len(segments) > 0:
|
||||
# list[np.array(1000, 2)] * num_samples
|
||||
segments = resample_segments(segments)
|
||||
# (N, 1000, 2)
|
||||
segments = np.stack(segments, axis=0)
|
||||
else:
|
||||
segments = np.zeros((0, 1000, 2), dtype=np.float32)
|
||||
self.segments = segments
|
||||
|
||||
def convert_bbox(self, format):
|
||||
"""Convert bounding box format."""
|
||||
self._bboxes.convert(format=format)
|
||||
|
||||
@property
|
||||
def bbox_areas(self):
|
||||
"""Calculate the area of bounding boxes."""
|
||||
return self._bboxes.areas()
|
||||
|
||||
def scale(self, scale_w, scale_h, bbox_only=False):
|
||||
"""this might be similar with denormalize func but without normalized sign."""
|
||||
self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
|
||||
if bbox_only:
|
||||
return
|
||||
self.segments[..., 0] *= scale_w
|
||||
self.segments[..., 1] *= scale_h
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 0] *= scale_w
|
||||
self.keypoints[..., 1] *= scale_h
|
||||
|
||||
def denormalize(self, w, h):
|
||||
"""Denormalizes boxes, segments, and keypoints from normalized coordinates."""
|
||||
if not self.normalized:
|
||||
return
|
||||
self._bboxes.mul(scale=(w, h, w, h))
|
||||
self.segments[..., 0] *= w
|
||||
self.segments[..., 1] *= h
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 0] *= w
|
||||
self.keypoints[..., 1] *= h
|
||||
self.normalized = False
|
||||
|
||||
def normalize(self, w, h):
|
||||
"""Normalize bounding boxes, segments, and keypoints to image dimensions."""
|
||||
if self.normalized:
|
||||
return
|
||||
self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
|
||||
self.segments[..., 0] /= w
|
||||
self.segments[..., 1] /= h
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 0] /= w
|
||||
self.keypoints[..., 1] /= h
|
||||
self.normalized = True
|
||||
|
||||
def add_padding(self, padw, padh):
|
||||
"""Handle rect and mosaic situation."""
|
||||
assert not self.normalized, 'you should add padding with absolute coordinates.'
|
||||
self._bboxes.add(offset=(padw, padh, padw, padh))
|
||||
self.segments[..., 0] += padw
|
||||
self.segments[..., 1] += padh
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 0] += padw
|
||||
self.keypoints[..., 1] += padh
|
||||
|
||||
def __getitem__(self, index) -> 'Instances':
|
||||
"""
|
||||
Retrieve a specific instance or a set of instances using indexing.
|
||||
|
||||
Args:
|
||||
index (int, slice, or np.ndarray): The index, slice, or boolean array to select
|
||||
the desired instances.
|
||||
|
||||
Returns:
|
||||
Instances: A new Instances object containing the selected bounding boxes,
|
||||
segments, and keypoints if present.
|
||||
|
||||
Note:
|
||||
When using boolean indexing, make sure to provide a boolean array with the same
|
||||
length as the number of instances.
|
||||
"""
|
||||
segments = self.segments[index] if len(self.segments) else self.segments
|
||||
keypoints = self.keypoints[index] if self.keypoints is not None else None
|
||||
bboxes = self.bboxes[index]
|
||||
bbox_format = self._bboxes.format
|
||||
return Instances(
|
||||
bboxes=bboxes,
|
||||
segments=segments,
|
||||
keypoints=keypoints,
|
||||
bbox_format=bbox_format,
|
||||
normalized=self.normalized,
|
||||
)
|
||||
|
||||
def flipud(self, h):
|
||||
"""Flips the coordinates of bounding boxes, segments, and keypoints vertically."""
|
||||
if self._bboxes.format == 'xyxy':
|
||||
y1 = self.bboxes[:, 1].copy()
|
||||
y2 = self.bboxes[:, 3].copy()
|
||||
self.bboxes[:, 1] = h - y2
|
||||
self.bboxes[:, 3] = h - y1
|
||||
else:
|
||||
self.bboxes[:, 1] = h - self.bboxes[:, 1]
|
||||
self.segments[..., 1] = h - self.segments[..., 1]
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 1] = h - self.keypoints[..., 1]
|
||||
|
||||
def fliplr(self, w):
|
||||
"""Reverses the order of the bounding boxes and segments horizontally."""
|
||||
if self._bboxes.format == 'xyxy':
|
||||
x1 = self.bboxes[:, 0].copy()
|
||||
x2 = self.bboxes[:, 2].copy()
|
||||
self.bboxes[:, 0] = w - x2
|
||||
self.bboxes[:, 2] = w - x1
|
||||
else:
|
||||
self.bboxes[:, 0] = w - self.bboxes[:, 0]
|
||||
self.segments[..., 0] = w - self.segments[..., 0]
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 0] = w - self.keypoints[..., 0]
|
||||
|
||||
def clip(self, w, h):
|
||||
"""Clips bounding boxes, segments, and keypoints values to stay within image boundaries."""
|
||||
ori_format = self._bboxes.format
|
||||
self.convert_bbox(format='xyxy')
|
||||
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
|
||||
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
|
||||
if ori_format != 'xyxy':
|
||||
self.convert_bbox(format=ori_format)
|
||||
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
|
||||
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
|
||||
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
|
||||
|
||||
def remove_zero_area_boxes(self):
|
||||
"""Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height. This removes them."""
|
||||
good = self.bbox_areas > 0
|
||||
if not all(good):
|
||||
self._bboxes = self._bboxes[good]
|
||||
if len(self.segments):
|
||||
self.segments = self.segments[good]
|
||||
if self.keypoints is not None:
|
||||
self.keypoints = self.keypoints[good]
|
||||
return good
|
||||
|
||||
def update(self, bboxes, segments=None, keypoints=None):
|
||||
"""Updates instance variables."""
|
||||
self._bboxes = Bboxes(bboxes, format=self._bboxes.format)
|
||||
if segments is not None:
|
||||
self.segments = segments
|
||||
if keypoints is not None:
|
||||
self.keypoints = keypoints
|
||||
|
||||
def __len__(self):
|
||||
"""Return the length of the instance list."""
|
||||
return len(self.bboxes)
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
|
||||
"""
|
||||
Concatenates a list of Instances objects into a single Instances object.
|
||||
|
||||
Args:
|
||||
instances_list (List[Instances]): A list of Instances objects to concatenate.
|
||||
axis (int, optional): The axis along which the arrays will be concatenated. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
Instances: A new Instances object containing the concatenated bounding boxes,
|
||||
segments, and keypoints if present.
|
||||
|
||||
Note:
|
||||
The `Instances` objects in the list should have the same properties, such as
|
||||
the format of the bounding boxes, whether keypoints are present, and if the
|
||||
coordinates are normalized.
|
||||
"""
|
||||
assert isinstance(instances_list, (list, tuple))
|
||||
if not instances_list:
|
||||
return cls(np.empty(0))
|
||||
assert all(isinstance(instance, Instances) for instance in instances_list)
|
||||
|
||||
if len(instances_list) == 1:
|
||||
return instances_list[0]
|
||||
|
||||
use_keypoint = instances_list[0].keypoints is not None
|
||||
bbox_format = instances_list[0]._bboxes.format
|
||||
normalized = instances_list[0].normalized
|
||||
|
||||
cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
|
||||
cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
|
||||
cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
|
||||
return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)
|
||||
|
||||
@property
|
||||
def bboxes(self):
|
||||
"""Return bounding boxes."""
|
||||
return self._bboxes.bboxes
|
@ -1,392 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.yolo.utils.metrics import OKS_SIGMA
|
||||
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
|
||||
from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
|
||||
|
||||
from .metrics import bbox_iou
|
||||
from .tal import bbox2dist
|
||||
|
||||
|
||||
class VarifocalLoss(nn.Module):
|
||||
"""Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the VarifocalLoss class."""
|
||||
super().__init__()
|
||||
|
||||
def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
|
||||
"""Computes varfocal loss."""
|
||||
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
|
||||
weight).mean(1).sum()
|
||||
return loss
|
||||
|
||||
|
||||
# Losses
|
||||
class FocalLoss(nn.Module):
|
||||
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
|
||||
|
||||
def __init__(self, ):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, pred, label, gamma=1.5, alpha=0.25):
|
||||
"""Calculates and updates confusion matrix for object detection/classification tasks."""
|
||||
loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
|
||||
# p_t = torch.exp(-loss)
|
||||
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
|
||||
|
||||
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
|
||||
pred_prob = pred.sigmoid() # prob from logits
|
||||
p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
|
||||
modulating_factor = (1.0 - p_t) ** gamma
|
||||
loss *= modulating_factor
|
||||
if alpha > 0:
|
||||
alpha_factor = label * alpha + (1 - label) * (1 - alpha)
|
||||
loss *= alpha_factor
|
||||
return loss.mean(1).sum()
|
||||
|
||||
|
||||
class BboxLoss(nn.Module):
|
||||
|
||||
def __init__(self, reg_max, use_dfl=False):
|
||||
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
|
||||
super().__init__()
|
||||
self.reg_max = reg_max
|
||||
self.use_dfl = use_dfl
|
||||
|
||||
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
|
||||
"""IoU loss."""
|
||||
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
|
||||
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
|
||||
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
|
||||
|
||||
# DFL loss
|
||||
if self.use_dfl:
|
||||
target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
|
||||
loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
|
||||
loss_dfl = loss_dfl.sum() / target_scores_sum
|
||||
else:
|
||||
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
|
||||
|
||||
return loss_iou, loss_dfl
|
||||
|
||||
@staticmethod
|
||||
def _df_loss(pred_dist, target):
|
||||
"""Return sum of left and right DFL losses."""
|
||||
# Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
|
||||
tl = target.long() # target left
|
||||
tr = tl + 1 # target right
|
||||
wl = tr - target # weight left
|
||||
wr = 1 - wl # weight right
|
||||
return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
|
||||
F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
|
||||
|
||||
|
||||
class KeypointLoss(nn.Module):
|
||||
|
||||
def __init__(self, sigmas) -> None:
|
||||
super().__init__()
|
||||
self.sigmas = sigmas
|
||||
|
||||
def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
|
||||
"""Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
|
||||
d = (pred_kpts[..., 0] - gt_kpts[..., 0]) ** 2 + (pred_kpts[..., 1] - gt_kpts[..., 1]) ** 2
|
||||
kpt_loss_factor = (torch.sum(kpt_mask != 0) + torch.sum(kpt_mask == 0)) / (torch.sum(kpt_mask != 0) + 1e-9)
|
||||
# e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
|
||||
e = d / (2 * self.sigmas) ** 2 / (area + 1e-9) / 2 # from cocoeval
|
||||
return kpt_loss_factor * ((1 - torch.exp(-e)) * kpt_mask).mean()
|
||||
|
||||
|
||||
# Criterion class for computing Detection training losses
|
||||
class v8DetectionLoss:
|
||||
|
||||
def __init__(self, model): # model must be de-paralleled
|
||||
|
||||
device = next(model.parameters()).device # get model device
|
||||
h = model.args # hyperparameters
|
||||
|
||||
m = model.model[-1] # Detect() module
|
||||
self.bce = nn.BCEWithLogitsLoss(reduction='none')
|
||||
self.hyp = h
|
||||
self.stride = m.stride # model strides
|
||||
self.nc = m.nc # number of classes
|
||||
self.no = m.no
|
||||
self.reg_max = m.reg_max
|
||||
self.device = device
|
||||
|
||||
self.use_dfl = m.reg_max > 1
|
||||
|
||||
self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
|
||||
self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
|
||||
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
|
||||
|
||||
def preprocess(self, targets, batch_size, scale_tensor):
|
||||
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
|
||||
if targets.shape[0] == 0:
|
||||
out = torch.zeros(batch_size, 0, 5, device=self.device)
|
||||
else:
|
||||
i = targets[:, 0] # image index
|
||||
_, counts = i.unique(return_counts=True)
|
||||
counts = counts.to(dtype=torch.int32)
|
||||
out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
|
||||
for j in range(batch_size):
|
||||
matches = i == j
|
||||
n = matches.sum()
|
||||
if n:
|
||||
out[j, :n] = targets[matches, 1:]
|
||||
out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
|
||||
return out
|
||||
|
||||
def bbox_decode(self, anchor_points, pred_dist):
|
||||
"""Decode predicted object bounding box coordinates from anchor points and distribution."""
|
||||
if self.use_dfl:
|
||||
b, a, c = pred_dist.shape # batch, anchors, channels
|
||||
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
|
||||
# pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
|
||||
# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
|
||||
return dist2bbox(pred_dist, anchor_points, xywh=False)
|
||||
|
||||
def __call__(self, preds, batch):
|
||||
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
|
||||
loss = torch.zeros(3, device=self.device) # box, cls, dfl
|
||||
feats = preds[1] if isinstance(preds, tuple) else preds
|
||||
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
||||
(self.reg_max * 4, self.nc), 1)
|
||||
|
||||
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
|
||||
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
|
||||
|
||||
dtype = pred_scores.dtype
|
||||
batch_size = pred_scores.shape[0]
|
||||
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
|
||||
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
||||
|
||||
# targets
|
||||
targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
|
||||
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
||||
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
||||
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
||||
|
||||
# pboxes
|
||||
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
|
||||
|
||||
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
|
||||
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
|
||||
|
||||
target_scores_sum = max(target_scores.sum(), 1)
|
||||
|
||||
# cls loss
|
||||
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
|
||||
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
|
||||
|
||||
# bbox loss
|
||||
if fg_mask.sum():
|
||||
target_bboxes /= stride_tensor
|
||||
loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
|
||||
target_scores_sum, fg_mask)
|
||||
|
||||
loss[0] *= self.hyp.box # box gain
|
||||
loss[1] *= self.hyp.cls # cls gain
|
||||
loss[2] *= self.hyp.dfl # dfl gain
|
||||
|
||||
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
||||
|
||||
|
||||
# Criterion class for computing training losses
|
||||
class v8SegmentationLoss(v8DetectionLoss):
|
||||
|
||||
def __init__(self, model): # model must be de-paralleled
|
||||
super().__init__(model)
|
||||
self.nm = model.model[-1].nm # number of masks
|
||||
self.overlap = model.args.overlap_mask
|
||||
|
||||
def __call__(self, preds, batch):
|
||||
"""Calculate and return the loss for the YOLO model."""
|
||||
loss = torch.zeros(4, device=self.device) # box, cls, dfl
|
||||
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
|
||||
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
|
||||
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
||||
(self.reg_max * 4, self.nc), 1)
|
||||
|
||||
# b, grids, ..
|
||||
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
|
||||
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
|
||||
pred_masks = pred_masks.permute(0, 2, 1).contiguous()
|
||||
|
||||
dtype = pred_scores.dtype
|
||||
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
|
||||
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
||||
|
||||
# targets
|
||||
try:
|
||||
batch_idx = batch['batch_idx'].view(-1, 1)
|
||||
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
|
||||
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
||||
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
||||
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
||||
except RuntimeError as e:
|
||||
raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
|
||||
"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
|
||||
"i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
|
||||
"correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
|
||||
'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e
|
||||
|
||||
# pboxes
|
||||
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
|
||||
|
||||
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
|
||||
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
|
||||
|
||||
target_scores_sum = max(target_scores.sum(), 1)
|
||||
|
||||
# cls loss
|
||||
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
|
||||
loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
|
||||
|
||||
if fg_mask.sum():
|
||||
# bbox loss
|
||||
loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
|
||||
target_scores, target_scores_sum, fg_mask)
|
||||
# masks loss
|
||||
masks = batch['masks'].to(self.device).float()
|
||||
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
|
||||
masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
|
||||
|
||||
for i in range(batch_size):
|
||||
if fg_mask[i].sum():
|
||||
mask_idx = target_gt_idx[i][fg_mask[i]]
|
||||
if self.overlap:
|
||||
gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
|
||||
else:
|
||||
gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
|
||||
xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
|
||||
marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
|
||||
mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
|
||||
loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea) # seg
|
||||
|
||||
# WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
|
||||
else:
|
||||
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
|
||||
|
||||
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
|
||||
else:
|
||||
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
|
||||
|
||||
loss[0] *= self.hyp.box # box gain
|
||||
loss[1] *= self.hyp.box / batch_size # seg gain
|
||||
loss[2] *= self.hyp.cls # cls gain
|
||||
loss[3] *= self.hyp.dfl # dfl gain
|
||||
|
||||
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
||||
|
||||
def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
|
||||
"""Mask loss for one image."""
|
||||
pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80)
|
||||
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
|
||||
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
|
||||
|
||||
|
||||
# Criterion class for computing training losses
|
||||
class v8PoseLoss(v8DetectionLoss):
|
||||
|
||||
def __init__(self, model): # model must be de-paralleled
|
||||
super().__init__(model)
|
||||
self.kpt_shape = model.model[-1].kpt_shape
|
||||
self.bce_pose = nn.BCEWithLogitsLoss()
|
||||
is_pose = self.kpt_shape == [17, 3]
|
||||
nkpt = self.kpt_shape[0] # number of keypoints
|
||||
sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
|
||||
self.keypoint_loss = KeypointLoss(sigmas=sigmas)
|
||||
|
||||
def __call__(self, preds, batch):
|
||||
"""Calculate the total loss and detach it."""
|
||||
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
|
||||
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
|
||||
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
||||
(self.reg_max * 4, self.nc), 1)
|
||||
|
||||
# b, grids, ..
|
||||
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
|
||||
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
|
||||
pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
|
||||
|
||||
dtype = pred_scores.dtype
|
||||
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
|
||||
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
|
||||
|
||||
# targets
|
||||
batch_size = pred_scores.shape[0]
|
||||
batch_idx = batch['batch_idx'].view(-1, 1)
|
||||
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
|
||||
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
|
||||
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
|
||||
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
|
||||
|
||||
# pboxes
|
||||
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
|
||||
pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)
|
||||
|
||||
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
|
||||
pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
|
||||
anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
|
||||
|
||||
target_scores_sum = max(target_scores.sum(), 1)
|
||||
|
||||
# cls loss
|
||||
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
|
||||
loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
|
||||
|
||||
# bbox loss
|
||||
if fg_mask.sum():
|
||||
target_bboxes /= stride_tensor
|
||||
loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
|
||||
target_scores_sum, fg_mask)
|
||||
keypoints = batch['keypoints'].to(self.device).float().clone()
|
||||
keypoints[..., 0] *= imgsz[1]
|
||||
keypoints[..., 1] *= imgsz[0]
|
||||
for i in range(batch_size):
|
||||
if fg_mask[i].sum():
|
||||
idx = target_gt_idx[i][fg_mask[i]]
|
||||
gt_kpt = keypoints[batch_idx.view(-1) == i][idx] # (n, 51)
|
||||
gt_kpt[..., 0] /= stride_tensor[fg_mask[i]]
|
||||
gt_kpt[..., 1] /= stride_tensor[fg_mask[i]]
|
||||
area = xyxy2xywh(target_bboxes[i][fg_mask[i]])[:, 2:].prod(1, keepdim=True)
|
||||
pred_kpt = pred_kpts[i][fg_mask[i]]
|
||||
kpt_mask = gt_kpt[..., 2] != 0
|
||||
loss[1] += self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss
|
||||
# kpt_score loss
|
||||
if pred_kpt.shape[-1] == 3:
|
||||
loss[2] += self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss
|
||||
|
||||
loss[0] *= self.hyp.box # box gain
|
||||
loss[1] *= self.hyp.pose / batch_size # pose gain
|
||||
loss[2] *= self.hyp.kobj / batch_size # kobj gain
|
||||
loss[3] *= self.hyp.cls # cls gain
|
||||
loss[4] *= self.hyp.dfl # dfl gain
|
||||
|
||||
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
||||
|
||||
def kpts_decode(self, anchor_points, pred_kpts):
|
||||
"""Decodes predicted keypoints to image coordinates."""
|
||||
y = pred_kpts.clone()
|
||||
y[..., :2] *= 2.0
|
||||
y[..., 0] += anchor_points[:, [0]] - 0.5
|
||||
y[..., 1] += anchor_points[:, [1]] - 0.5
|
||||
return y
|
||||
|
||||
|
||||
class v8ClassificationLoss:
|
||||
|
||||
def __call__(self, preds, batch):
|
||||
"""Compute the classification loss between predictions and true labels."""
|
||||
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / 64
|
||||
loss_items = loss.detach()
|
||||
return loss, loss_items
|
@ -1,977 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Model validation metrics
|
||||
"""
|
||||
import math
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept, plt_settings
|
||||
|
||||
OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
|
||||
|
||||
|
||||
# Boxes
|
||||
def box_area(box):
|
||||
"""Return box area, where box shape is xyxy(4,n)."""
|
||||
return (box[2] - box[0]) * (box[3] - box[1])
|
||||
|
||||
|
||||
def bbox_ioa(box1, box2, eps=1e-7):
|
||||
"""
|
||||
Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
|
||||
|
||||
Args:
|
||||
box1 (np.array): A numpy array of shape (n, 4) representing n bounding boxes.
|
||||
box2 (np.array): A numpy array of shape (m, 4) representing m bounding boxes.
|
||||
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
|
||||
|
||||
Returns:
|
||||
(np.array): A numpy array of shape (n, m) representing the intersection over box2 area.
|
||||
"""
|
||||
|
||||
# Get the coordinates of bounding boxes
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
|
||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
|
||||
|
||||
# Intersection area
|
||||
inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * \
|
||||
(np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)
|
||||
|
||||
# box2 area
|
||||
box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
|
||||
|
||||
# Intersection over box2 area
|
||||
return inter_area / box2_area
|
||||
|
||||
|
||||
def box_iou(box1, box2, eps=1e-7):
|
||||
"""
|
||||
Calculate intersection-over-union (IoU) of boxes.
|
||||
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
||||
Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
|
||||
|
||||
Args:
|
||||
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
|
||||
box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
|
||||
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
|
||||
"""
|
||||
|
||||
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
|
||||
(a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
|
||||
inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)
|
||||
|
||||
# IoU = inter / (area1 + area2 - inter)
|
||||
return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
|
||||
|
||||
|
||||
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
|
||||
"""
|
||||
Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).
|
||||
|
||||
Args:
|
||||
box1 (torch.Tensor): A tensor representing a single bounding box with shape (1, 4).
|
||||
box2 (torch.Tensor): A tensor representing n bounding boxes with shape (n, 4).
|
||||
xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
|
||||
(x1, y1, x2, y2) format. Defaults to True.
|
||||
GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
|
||||
DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
|
||||
CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
|
||||
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
|
||||
"""
|
||||
|
||||
# Get the coordinates of bounding boxes
|
||||
if xywh: # transform from xywh to xyxy
|
||||
(x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
|
||||
w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
|
||||
b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
|
||||
b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
|
||||
else: # x1, y1, x2, y2 = box1
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
|
||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
|
||||
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
|
||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
||||
|
||||
# Intersection area
|
||||
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * \
|
||||
(b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp_(0)
|
||||
|
||||
# Union Area
|
||||
union = w1 * h1 + w2 * h2 - inter + eps
|
||||
|
||||
# IoU
|
||||
iou = inter / union
|
||||
if CIoU or DIoU or GIoU:
|
||||
cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
|
||||
ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
|
||||
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
|
||||
c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
|
||||
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
|
||||
if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
|
||||
v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
|
||||
with torch.no_grad():
|
||||
alpha = v / (v - iou + (1 + eps))
|
||||
return iou - (rho2 / c2 + v * alpha) # CIoU
|
||||
return iou - rho2 / c2 # DIoU
|
||||
c_area = cw * ch + eps # convex area
|
||||
return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
|
||||
return iou # IoU
|
||||
|
||||
|
||||
def mask_iou(mask1, mask2, eps=1e-7):
|
||||
"""
|
||||
Calculate masks IoU.
|
||||
|
||||
Args:
|
||||
mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
|
||||
product of image width and height.
|
||||
mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
|
||||
product of image width and height.
|
||||
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): A tensor of shape (N, M) representing masks IoU.
|
||||
"""
|
||||
intersection = torch.matmul(mask1, mask2.T).clamp_(0)
|
||||
union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection
|
||||
return intersection / (union + eps)
|
||||
|
||||
|
||||
def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
|
||||
"""
|
||||
Calculate Object Keypoint Similarity (OKS).
|
||||
|
||||
Args:
|
||||
kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
|
||||
kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.
|
||||
area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.
|
||||
sigma (list): A list containing 17 values representing keypoint scales.
|
||||
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.
|
||||
"""
|
||||
d = (kpt1[:, None, :, 0] - kpt2[..., 0]) ** 2 + (kpt1[:, None, :, 1] - kpt2[..., 1]) ** 2 # (N, M, 17)
|
||||
sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, )
|
||||
kpt_mask = kpt1[..., 2] != 0 # (N, 17)
|
||||
e = d / (2 * sigma) ** 2 / (area[:, None, None] + eps) / 2 # from cocoeval
|
||||
# e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2 # from formula
|
||||
return (torch.exp(-e) * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)
|
||||
|
||||
|
||||
def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
|
||||
# return positive, negative label smoothing BCE targets
|
||||
return 1.0 - 0.5 * eps, 0.5 * eps
|
||||
|
||||
|
||||
class ConfusionMatrix:
|
||||
"""
|
||||
A class for calculating and updating a confusion matrix for object detection and classification tasks.
|
||||
|
||||
Attributes:
|
||||
task (str): The type of task, either 'detect' or 'classify'.
|
||||
matrix (np.array): The confusion matrix, with dimensions depending on the task.
|
||||
nc (int): The number of classes.
|
||||
conf (float): The confidence threshold for detections.
|
||||
iou_thres (float): The Intersection over Union threshold.
|
||||
"""
|
||||
|
||||
def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'):
|
||||
"""Initialize attributes for the YOLO model."""
|
||||
self.task = task
|
||||
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc))
|
||||
self.nc = nc # number of classes
|
||||
self.conf = conf
|
||||
self.iou_thres = iou_thres
|
||||
|
||||
def process_cls_preds(self, preds, targets):
|
||||
"""
|
||||
Update confusion matrix for classification task
|
||||
|
||||
Args:
|
||||
preds (Array[N, min(nc,5)]): Predicted class labels.
|
||||
targets (Array[N, 1]): Ground truth class labels.
|
||||
"""
|
||||
preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
|
||||
for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
|
||||
self.matrix[p][t] += 1
|
||||
|
||||
def process_batch(self, detections, labels):
|
||||
"""
|
||||
Update confusion matrix for object detection task.
|
||||
|
||||
Args:
|
||||
detections (Array[N, 6]): Detected bounding boxes and their associated information.
|
||||
Each row should contain (x1, y1, x2, y2, conf, class).
|
||||
labels (Array[M, 5]): Ground truth bounding boxes and their associated class labels.
|
||||
Each row should contain (class, x1, y1, x2, y2).
|
||||
"""
|
||||
if detections is None:
|
||||
gt_classes = labels.int()
|
||||
for gc in gt_classes:
|
||||
self.matrix[self.nc, gc] += 1 # background FN
|
||||
return
|
||||
|
||||
detections = detections[detections[:, 4] > self.conf]
|
||||
gt_classes = labels[:, 0].int()
|
||||
detection_classes = detections[:, 5].int()
|
||||
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||
|
||||
x = torch.where(iou > self.iou_thres)
|
||||
if x[0].shape[0]:
|
||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
|
||||
if x[0].shape[0] > 1:
|
||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
||||
else:
|
||||
matches = np.zeros((0, 3))
|
||||
|
||||
n = matches.shape[0] > 0
|
||||
m0, m1, _ = matches.transpose().astype(int)
|
||||
for i, gc in enumerate(gt_classes):
|
||||
j = m0 == i
|
||||
if n and sum(j) == 1:
|
||||
self.matrix[detection_classes[m1[j]], gc] += 1 # correct
|
||||
else:
|
||||
self.matrix[self.nc, gc] += 1 # true background
|
||||
|
||||
if n:
|
||||
for i, dc in enumerate(detection_classes):
|
||||
if not any(m1 == i):
|
||||
self.matrix[dc, self.nc] += 1 # predicted background
|
||||
|
||||
def matrix(self):
|
||||
"""Returns the confusion matrix."""
|
||||
return self.matrix
|
||||
|
||||
def tp_fp(self):
|
||||
"""Returns true positives and false positives."""
|
||||
tp = self.matrix.diagonal() # true positives
|
||||
fp = self.matrix.sum(1) - tp # false positives
|
||||
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
|
||||
return (tp[:-1], fp[:-1]) if self.task == 'detect' else (tp, fp) # remove background class if task=detect
|
||||
|
||||
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
|
||||
@plt_settings()
|
||||
def plot(self, normalize=True, save_dir='', names=(), on_plot=None):
|
||||
"""
|
||||
Plot the confusion matrix using seaborn and save it to a file.
|
||||
|
||||
Args:
|
||||
normalize (bool): Whether to normalize the confusion matrix.
|
||||
save_dir (str): Directory where the plot will be saved.
|
||||
names (tuple): Names of classes, used as labels on the plot.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
||||
"""
|
||||
import seaborn as sn
|
||||
|
||||
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
|
||||
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
|
||||
nc, nn = self.nc, len(names) # number of classes, names
|
||||
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
||||
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
||||
ticklabels = (list(names) + ['background']) if labels else 'auto'
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
||||
sn.heatmap(array,
|
||||
ax=ax,
|
||||
annot=nc < 30,
|
||||
annot_kws={
|
||||
'size': 8},
|
||||
cmap='Blues',
|
||||
fmt='.2f' if normalize else '.0f',
|
||||
square=True,
|
||||
vmin=0.0,
|
||||
xticklabels=ticklabels,
|
||||
yticklabels=ticklabels).set_facecolor((1, 1, 1))
|
||||
title = 'Confusion Matrix' + ' Normalized' * normalize
|
||||
ax.set_xlabel('True')
|
||||
ax.set_ylabel('Predicted')
|
||||
ax.set_title(title)
|
||||
plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png'
|
||||
fig.savefig(plot_fname, dpi=250)
|
||||
plt.close(fig)
|
||||
if on_plot:
|
||||
on_plot(plot_fname)
|
||||
|
||||
def print(self):
|
||||
"""
|
||||
Print the confusion matrix to the console.
|
||||
"""
|
||||
for i in range(self.nc + 1):
|
||||
LOGGER.info(' '.join(map(str, self.matrix[i])))
|
||||
|
||||
|
||||
def smooth(y, f=0.05):
|
||||
"""Box filter of fraction f."""
|
||||
nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
|
||||
p = np.ones(nf // 2) # ones padding
|
||||
yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
|
||||
return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
|
||||
|
||||
|
||||
@plt_settings()
|
||||
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=None):
|
||||
"""Plots a precision-recall curve."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||
py = np.stack(py, axis=1)
|
||||
|
||||
if 0 < len(names) < 21: # display per-class legend if < 21 classes
|
||||
for i, y in enumerate(py.T):
|
||||
ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
|
||||
else:
|
||||
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
|
||||
|
||||
ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
|
||||
ax.set_xlabel('Recall')
|
||||
ax.set_ylabel('Precision')
|
||||
ax.set_xlim(0, 1)
|
||||
ax.set_ylim(0, 1)
|
||||
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
|
||||
ax.set_title('Precision-Recall Curve')
|
||||
fig.savefig(save_dir, dpi=250)
|
||||
plt.close(fig)
|
||||
if on_plot:
|
||||
on_plot(save_dir)
|
||||
|
||||
|
||||
@plt_settings()
|
||||
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric', on_plot=None):
|
||||
"""Plots a metric-confidence curve."""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
||||
|
||||
if 0 < len(names) < 21: # display per-class legend if < 21 classes
|
||||
for i, y in enumerate(py):
|
||||
ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
|
||||
else:
|
||||
ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
|
||||
|
||||
y = smooth(py.mean(0), 0.05)
|
||||
ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
|
||||
ax.set_xlabel(xlabel)
|
||||
ax.set_ylabel(ylabel)
|
||||
ax.set_xlim(0, 1)
|
||||
ax.set_ylim(0, 1)
|
||||
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
|
||||
ax.set_title(f'{ylabel}-Confidence Curve')
|
||||
fig.savefig(save_dir, dpi=250)
|
||||
plt.close(fig)
|
||||
if on_plot:
|
||||
on_plot(save_dir)
|
||||
|
||||
|
||||
def compute_ap(recall, precision):
|
||||
"""
|
||||
Compute the average precision (AP) given the recall and precision curves.
|
||||
|
||||
Arguments:
|
||||
recall (list): The recall curve.
|
||||
precision (list): The precision curve.
|
||||
|
||||
Returns:
|
||||
(float): Average precision.
|
||||
(np.ndarray): Precision envelope curve.
|
||||
(np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
|
||||
"""
|
||||
|
||||
# Append sentinel values to beginning and end
|
||||
mrec = np.concatenate(([0.0], recall, [1.0]))
|
||||
mpre = np.concatenate(([1.0], precision, [0.0]))
|
||||
|
||||
# Compute the precision envelope
|
||||
mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
|
||||
|
||||
# Integrate area under curve
|
||||
method = 'interp' # methods: 'continuous', 'interp'
|
||||
if method == 'interp':
|
||||
x = np.linspace(0, 1, 101) # 101-point interp (COCO)
|
||||
ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
|
||||
else: # 'continuous'
|
||||
i = np.where(mrec[1:] != mrec[:-1])[0] # points where x-axis (recall) changes
|
||||
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
|
||||
|
||||
return ap, mpre, mrec
|
||||
|
||||
|
||||
def ap_per_class(tp,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=False,
|
||||
on_plot=None,
|
||||
save_dir=Path(),
|
||||
names=(),
|
||||
eps=1e-16,
|
||||
prefix=''):
|
||||
"""
|
||||
Computes the average precision per class for object detection evaluation.
|
||||
|
||||
Args:
|
||||
tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
|
||||
conf (np.ndarray): Array of confidence scores of the detections.
|
||||
pred_cls (np.ndarray): Array of predicted classes of the detections.
|
||||
target_cls (np.ndarray): Array of true classes of the detections.
|
||||
plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
|
||||
on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None.
|
||||
save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
|
||||
names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple.
|
||||
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
|
||||
prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple of six arrays and one array of unique classes, where:
|
||||
tp (np.ndarray): True positive counts for each class.
|
||||
fp (np.ndarray): False positive counts for each class.
|
||||
p (np.ndarray): Precision values at each confidence threshold.
|
||||
r (np.ndarray): Recall values at each confidence threshold.
|
||||
f1 (np.ndarray): F1-score values at each confidence threshold.
|
||||
ap (np.ndarray): Average precision for each class at different IoU thresholds.
|
||||
unique_classes (np.ndarray): An array of unique classes that have data.
|
||||
|
||||
"""
|
||||
|
||||
# Sort by objectness
|
||||
i = np.argsort(-conf)
|
||||
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
|
||||
|
||||
# Find unique classes
|
||||
unique_classes, nt = np.unique(target_cls, return_counts=True)
|
||||
nc = unique_classes.shape[0] # number of classes, number of detections
|
||||
|
||||
# Create Precision-Recall curve and compute AP for each class
|
||||
px, py = np.linspace(0, 1, 1000), [] # for plotting
|
||||
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
|
||||
for ci, c in enumerate(unique_classes):
|
||||
i = pred_cls == c
|
||||
n_l = nt[ci] # number of labels
|
||||
n_p = i.sum() # number of predictions
|
||||
if n_p == 0 or n_l == 0:
|
||||
continue
|
||||
|
||||
# Accumulate FPs and TPs
|
||||
fpc = (1 - tp[i]).cumsum(0)
|
||||
tpc = tp[i].cumsum(0)
|
||||
|
||||
# Recall
|
||||
recall = tpc / (n_l + eps) # recall curve
|
||||
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
|
||||
|
||||
# Precision
|
||||
precision = tpc / (tpc + fpc) # precision curve
|
||||
p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
|
||||
|
||||
# AP from recall-precision curve
|
||||
for j in range(tp.shape[1]):
|
||||
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
|
||||
if plot and j == 0:
|
||||
py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
|
||||
|
||||
# Compute F1 (harmonic mean of precision and recall)
|
||||
f1 = 2 * p * r / (p + r + eps)
|
||||
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
|
||||
names = dict(enumerate(names)) # to dict
|
||||
if plot:
|
||||
plot_pr_curve(px, py, ap, save_dir / f'{prefix}PR_curve.png', names, on_plot=on_plot)
|
||||
plot_mc_curve(px, f1, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1', on_plot=on_plot)
|
||||
plot_mc_curve(px, p, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision', on_plot=on_plot)
|
||||
plot_mc_curve(px, r, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall', on_plot=on_plot)
|
||||
|
||||
i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
|
||||
p, r, f1 = p[:, i], r[:, i], f1[:, i]
|
||||
tp = (r * nt).round() # true positives
|
||||
fp = (tp / (p + eps) - tp).round() # false positives
|
||||
return tp, fp, p, r, f1, ap, unique_classes.astype(int)
|
||||
|
||||
|
||||
class Metric(SimpleClass):
|
||||
"""
|
||||
Class for computing evaluation metrics for YOLOv8 model.
|
||||
|
||||
Attributes:
|
||||
p (list): Precision for each class. Shape: (nc,).
|
||||
r (list): Recall for each class. Shape: (nc,).
|
||||
f1 (list): F1 score for each class. Shape: (nc,).
|
||||
all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
|
||||
ap_class_index (list): Index of class for each AP score. Shape: (nc,).
|
||||
nc (int): Number of classes.
|
||||
|
||||
Methods:
|
||||
ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
|
||||
ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
|
||||
mp(): Mean precision of all classes. Returns: Float.
|
||||
mr(): Mean recall of all classes. Returns: Float.
|
||||
map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.
|
||||
map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.
|
||||
map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.
|
||||
mean_results(): Mean of results, returns mp, mr, map50, map.
|
||||
class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].
|
||||
maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).
|
||||
fitness(): Model fitness as a weighted combination of metrics. Returns: Float.
|
||||
update(results): Update metric attributes with new evaluation results.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.p = [] # (nc, )
|
||||
self.r = [] # (nc, )
|
||||
self.f1 = [] # (nc, )
|
||||
self.all_ap = [] # (nc, 10)
|
||||
self.ap_class_index = [] # (nc, )
|
||||
self.nc = 0
|
||||
|
||||
@property
|
||||
def ap50(self):
|
||||
"""
|
||||
Returns the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
|
||||
|
||||
Returns:
|
||||
(np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
|
||||
"""
|
||||
return self.all_ap[:, 0] if len(self.all_ap) else []
|
||||
|
||||
@property
|
||||
def ap(self):
|
||||
"""
|
||||
Returns the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
|
||||
|
||||
Returns:
|
||||
(np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
|
||||
"""
|
||||
return self.all_ap.mean(1) if len(self.all_ap) else []
|
||||
|
||||
@property
|
||||
def mp(self):
|
||||
"""
|
||||
Returns the Mean Precision of all classes.
|
||||
|
||||
Returns:
|
||||
(float): The mean precision of all classes.
|
||||
"""
|
||||
return self.p.mean() if len(self.p) else 0.0
|
||||
|
||||
@property
|
||||
def mr(self):
|
||||
"""
|
||||
Returns the Mean Recall of all classes.
|
||||
|
||||
Returns:
|
||||
(float): The mean recall of all classes.
|
||||
"""
|
||||
return self.r.mean() if len(self.r) else 0.0
|
||||
|
||||
@property
|
||||
def map50(self):
|
||||
"""
|
||||
Returns the mean Average Precision (mAP) at an IoU threshold of 0.5.
|
||||
|
||||
Returns:
|
||||
(float): The mAP50 at an IoU threshold of 0.5.
|
||||
"""
|
||||
return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
|
||||
|
||||
@property
|
||||
def map75(self):
|
||||
"""
|
||||
Returns the mean Average Precision (mAP) at an IoU threshold of 0.75.
|
||||
|
||||
Returns:
|
||||
(float): The mAP50 at an IoU threshold of 0.75.
|
||||
"""
|
||||
return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
|
||||
|
||||
@property
|
||||
def map(self):
|
||||
"""
|
||||
Returns the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
|
||||
|
||||
Returns:
|
||||
(float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
|
||||
"""
|
||||
return self.all_ap.mean() if len(self.all_ap) else 0.0
|
||||
|
||||
def mean_results(self):
|
||||
"""Mean of results, return mp, mr, map50, map."""
|
||||
return [self.mp, self.mr, self.map50, self.map]
|
||||
|
||||
def class_result(self, i):
|
||||
"""class-aware result, return p[i], r[i], ap50[i], ap[i]."""
|
||||
return self.p[i], self.r[i], self.ap50[i], self.ap[i]
|
||||
|
||||
@property
|
||||
def maps(self):
|
||||
"""mAP of each class."""
|
||||
maps = np.zeros(self.nc) + self.map
|
||||
for i, c in enumerate(self.ap_class_index):
|
||||
maps[c] = self.ap[i]
|
||||
return maps
|
||||
|
||||
def fitness(self):
|
||||
"""Model fitness as a weighted combination of metrics."""
|
||||
w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
|
||||
return (np.array(self.mean_results()) * w).sum()
|
||||
|
||||
def update(self, results):
|
||||
"""
|
||||
Args:
|
||||
results (tuple): A tuple of (p, r, ap, f1, ap_class)
|
||||
"""
|
||||
self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results
|
||||
|
||||
|
||||
class DetMetrics(SimpleClass):
|
||||
"""
|
||||
This class is a utility class for computing detection metrics such as precision, recall, and mean average precision
|
||||
(mAP) of an object detection model.
|
||||
|
||||
Args:
|
||||
save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
|
||||
plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
|
||||
names (tuple of str): A tuple of strings that represents the names of the classes. Defaults to an empty tuple.
|
||||
|
||||
Attributes:
|
||||
save_dir (Path): A path to the directory where the output plots will be saved.
|
||||
plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
||||
names (tuple of str): A tuple of strings that represents the names of the classes.
|
||||
box (Metric): An instance of the Metric class for storing the results of the detection metrics.
|
||||
speed (dict): A dictionary for storing the execution time of different parts of the detection process.
|
||||
|
||||
Methods:
|
||||
process(tp, conf, pred_cls, target_cls): Updates the metric results with the latest batch of predictions.
|
||||
keys: Returns a list of keys for accessing the computed detection metrics.
|
||||
mean_results: Returns a list of mean values for the computed detection metrics.
|
||||
class_result(i): Returns a list of values for the computed detection metrics for a specific class.
|
||||
maps: Returns a dictionary of mean average precision (mAP) values for different IoU thresholds.
|
||||
fitness: Computes the fitness score based on the computed detection metrics.
|
||||
ap_class_index: Returns a list of class indices sorted by their average precision (AP) values.
|
||||
results_dict: Returns a dictionary that maps detection metric keys to their computed values.
|
||||
"""
|
||||
|
||||
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
||||
self.save_dir = save_dir
|
||||
self.plot = plot
|
||||
self.on_plot = on_plot
|
||||
self.names = names
|
||||
self.box = Metric()
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
|
||||
def process(self, tp, conf, pred_cls, target_cls):
|
||||
"""Process predicted results for object detection and update metrics."""
|
||||
results = ap_per_class(tp,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)[2:]
|
||||
self.box.nc = len(self.names)
|
||||
self.box.update(results)
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
"""Returns a list of keys for accessing specific metrics."""
|
||||
return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
|
||||
|
||||
def mean_results(self):
|
||||
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
|
||||
return self.box.mean_results()
|
||||
|
||||
def class_result(self, i):
|
||||
"""Return the result of evaluating the performance of an object detection model on a specific class."""
|
||||
return self.box.class_result(i)
|
||||
|
||||
@property
|
||||
def maps(self):
|
||||
"""Returns mean Average Precision (mAP) scores per class."""
|
||||
return self.box.maps
|
||||
|
||||
@property
|
||||
def fitness(self):
|
||||
"""Returns the fitness of box object."""
|
||||
return self.box.fitness()
|
||||
|
||||
@property
|
||||
def ap_class_index(self):
|
||||
"""Returns the average precision index per class."""
|
||||
return self.box.ap_class_index
|
||||
|
||||
@property
|
||||
def results_dict(self):
|
||||
"""Returns dictionary of computed performance metrics and statistics."""
|
||||
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
|
||||
|
||||
|
||||
class SegmentMetrics(SimpleClass):
|
||||
"""
|
||||
Calculates and aggregates detection and segmentation metrics over a given set of classes.
|
||||
|
||||
Args:
|
||||
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
|
||||
plot (bool): Whether to save the detection and segmentation plots. Default is False.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
|
||||
names (list): List of class names. Default is an empty list.
|
||||
|
||||
Attributes:
|
||||
save_dir (Path): Path to the directory where the output plots should be saved.
|
||||
plot (bool): Whether to save the detection and segmentation plots.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
||||
names (list): List of class names.
|
||||
box (Metric): An instance of the Metric class to calculate box detection metrics.
|
||||
seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
|
||||
speed (dict): Dictionary to store the time taken in different phases of inference.
|
||||
|
||||
Methods:
|
||||
process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
|
||||
mean_results(): Returns the mean of the detection and segmentation metrics over all the classes.
|
||||
class_result(i): Returns the detection and segmentation metrics of class `i`.
|
||||
maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
|
||||
fitness: Returns the fitness scores, which are a single weighted combination of metrics.
|
||||
ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
|
||||
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
|
||||
"""
|
||||
|
||||
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
||||
self.save_dir = save_dir
|
||||
self.plot = plot
|
||||
self.on_plot = on_plot
|
||||
self.names = names
|
||||
self.box = Metric()
|
||||
self.seg = Metric()
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
|
||||
def process(self, tp_b, tp_m, conf, pred_cls, target_cls):
|
||||
"""
|
||||
Processes the detection and segmentation metrics over the given set of predictions.
|
||||
|
||||
Args:
|
||||
tp_b (list): List of True Positive boxes.
|
||||
tp_m (list): List of True Positive masks.
|
||||
conf (list): List of confidence scores.
|
||||
pred_cls (list): List of predicted classes.
|
||||
target_cls (list): List of target classes.
|
||||
"""
|
||||
|
||||
results_mask = ap_per_class(tp_m,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix='Mask')[2:]
|
||||
self.seg.nc = len(self.names)
|
||||
self.seg.update(results_mask)
|
||||
results_box = ap_per_class(tp_b,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix='Box')[2:]
|
||||
self.box.nc = len(self.names)
|
||||
self.box.update(results_box)
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
"""Returns a list of keys for accessing metrics."""
|
||||
return [
|
||||
'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
|
||||
'metrics/precision(M)', 'metrics/recall(M)', 'metrics/mAP50(M)', 'metrics/mAP50-95(M)']
|
||||
|
||||
def mean_results(self):
|
||||
"""Return the mean metrics for bounding box and segmentation results."""
|
||||
return self.box.mean_results() + self.seg.mean_results()
|
||||
|
||||
def class_result(self, i):
|
||||
"""Returns classification results for a specified class index."""
|
||||
return self.box.class_result(i) + self.seg.class_result(i)
|
||||
|
||||
@property
|
||||
def maps(self):
|
||||
"""Returns mAP scores for object detection and semantic segmentation models."""
|
||||
return self.box.maps + self.seg.maps
|
||||
|
||||
@property
|
||||
def fitness(self):
|
||||
"""Get the fitness score for both segmentation and bounding box models."""
|
||||
return self.seg.fitness() + self.box.fitness()
|
||||
|
||||
@property
|
||||
def ap_class_index(self):
|
||||
"""Boxes and masks have the same ap_class_index."""
|
||||
return self.box.ap_class_index
|
||||
|
||||
@property
|
||||
def results_dict(self):
|
||||
"""Returns results of object detection model for evaluation."""
|
||||
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
|
||||
|
||||
|
||||
class PoseMetrics(SegmentMetrics):
|
||||
"""
|
||||
Calculates and aggregates detection and pose metrics over a given set of classes.
|
||||
|
||||
Args:
|
||||
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
|
||||
plot (bool): Whether to save the detection and segmentation plots. Default is False.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
|
||||
names (list): List of class names. Default is an empty list.
|
||||
|
||||
Attributes:
|
||||
save_dir (Path): Path to the directory where the output plots should be saved.
|
||||
plot (bool): Whether to save the detection and segmentation plots.
|
||||
on_plot (func): An optional callback to pass plots path and data when they are rendered.
|
||||
names (list): List of class names.
|
||||
box (Metric): An instance of the Metric class to calculate box detection metrics.
|
||||
pose (Metric): An instance of the Metric class to calculate mask segmentation metrics.
|
||||
speed (dict): Dictionary to store the time taken in different phases of inference.
|
||||
|
||||
Methods:
|
||||
process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
|
||||
mean_results(): Returns the mean of the detection and segmentation metrics over all the classes.
|
||||
class_result(i): Returns the detection and segmentation metrics of class `i`.
|
||||
maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
|
||||
fitness: Returns the fitness scores, which are a single weighted combination of metrics.
|
||||
ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
|
||||
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
|
||||
"""
|
||||
|
||||
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
|
||||
super().__init__(save_dir, plot, names)
|
||||
self.save_dir = save_dir
|
||||
self.plot = plot
|
||||
self.on_plot = on_plot
|
||||
self.names = names
|
||||
self.box = Metric()
|
||||
self.pose = Metric()
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
|
||||
def __getattr__(self, attr):
|
||||
"""Raises an AttributeError if an invalid attribute is accessed."""
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
def process(self, tp_b, tp_p, conf, pred_cls, target_cls):
|
||||
"""
|
||||
Processes the detection and pose metrics over the given set of predictions.
|
||||
|
||||
Args:
|
||||
tp_b (list): List of True Positive boxes.
|
||||
tp_p (list): List of True Positive keypoints.
|
||||
conf (list): List of confidence scores.
|
||||
pred_cls (list): List of predicted classes.
|
||||
target_cls (list): List of target classes.
|
||||
"""
|
||||
|
||||
results_pose = ap_per_class(tp_p,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix='Pose')[2:]
|
||||
self.pose.nc = len(self.names)
|
||||
self.pose.update(results_pose)
|
||||
results_box = ap_per_class(tp_b,
|
||||
conf,
|
||||
pred_cls,
|
||||
target_cls,
|
||||
plot=self.plot,
|
||||
on_plot=self.on_plot,
|
||||
save_dir=self.save_dir,
|
||||
names=self.names,
|
||||
prefix='Box')[2:]
|
||||
self.box.nc = len(self.names)
|
||||
self.box.update(results_box)
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
"""Returns list of evaluation metric keys."""
|
||||
return [
|
||||
'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
|
||||
'metrics/precision(P)', 'metrics/recall(P)', 'metrics/mAP50(P)', 'metrics/mAP50-95(P)']
|
||||
|
||||
def mean_results(self):
|
||||
"""Return the mean results of box and pose."""
|
||||
return self.box.mean_results() + self.pose.mean_results()
|
||||
|
||||
def class_result(self, i):
|
||||
"""Return the class-wise detection results for a specific class i."""
|
||||
return self.box.class_result(i) + self.pose.class_result(i)
|
||||
|
||||
@property
|
||||
def maps(self):
|
||||
"""Returns the mean average precision (mAP) per class for both box and pose detections."""
|
||||
return self.box.maps + self.pose.maps
|
||||
|
||||
@property
|
||||
def fitness(self):
|
||||
"""Computes classification metrics and speed using the `targets` and `pred` inputs."""
|
||||
return self.pose.fitness() + self.box.fitness()
|
||||
|
||||
|
||||
class ClassifyMetrics(SimpleClass):
|
||||
"""
|
||||
Class for computing classification metrics including top-1 and top-5 accuracy.
|
||||
|
||||
Attributes:
|
||||
top1 (float): The top-1 accuracy.
|
||||
top5 (float): The top-5 accuracy.
|
||||
speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline.
|
||||
|
||||
Properties:
|
||||
fitness (float): The fitness of the model, which is equal to top-5 accuracy.
|
||||
results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness.
|
||||
keys (List[str]): A list of keys for the results_dict.
|
||||
|
||||
Methods:
|
||||
process(targets, pred): Processes the targets and predictions to compute classification metrics.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.top1 = 0
|
||||
self.top5 = 0
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
|
||||
def process(self, targets, pred):
|
||||
"""Target classes and predicted classes."""
|
||||
pred, targets = torch.cat(pred), torch.cat(targets)
|
||||
correct = (targets[:, None] == pred).float()
|
||||
acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy
|
||||
self.top1, self.top5 = acc.mean(0).tolist()
|
||||
|
||||
@property
|
||||
def fitness(self):
|
||||
"""Returns top-5 accuracy as fitness score."""
|
||||
return self.top5
|
||||
|
||||
@property
|
||||
def results_dict(self):
|
||||
"""Returns a dictionary with model's performance metrics and fitness score."""
|
||||
return dict(zip(self.keys + ['fitness'], [self.top1, self.top5, self.fitness]))
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
"""Returns a list of keys for the results_dict property."""
|
||||
return ['metrics/accuracy_top1', 'metrics/accuracy_top5']
|
@ -1,739 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
import re
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER
|
||||
|
||||
from .metrics import box_iou
|
||||
|
||||
|
||||
class Profile(contextlib.ContextDecorator):
|
||||
"""
|
||||
YOLOv8 Profile class.
|
||||
Usage: as a decorator with @Profile() or as a context manager with 'with Profile():'
|
||||
"""
|
||||
|
||||
def __init__(self, t=0.0):
|
||||
"""
|
||||
Initialize the Profile class.
|
||||
|
||||
Args:
|
||||
t (float): Initial time. Defaults to 0.0.
|
||||
"""
|
||||
self.t = t
|
||||
self.cuda = torch.cuda.is_available()
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
Start timing.
|
||||
"""
|
||||
self.start = self.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
"""
|
||||
Stop timing.
|
||||
"""
|
||||
self.dt = self.time() - self.start # delta-time
|
||||
self.t += self.dt # accumulate dt
|
||||
|
||||
def time(self):
|
||||
"""
|
||||
Get current time.
|
||||
"""
|
||||
if self.cuda:
|
||||
torch.cuda.synchronize()
|
||||
return time.time()
|
||||
|
||||
|
||||
def coco80_to_coco91_class(): #
|
||||
"""
|
||||
Converts 80-index (val2014) to 91-index (paper).
|
||||
For details see https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/.
|
||||
|
||||
Example:
|
||||
a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
|
||||
b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
|
||||
x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
|
||||
x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
|
||||
"""
|
||||
return [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
|
||||
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
|
||||
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
|
||||
|
||||
|
||||
def segment2box(segment, width=640, height=640):
|
||||
"""
|
||||
Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
|
||||
|
||||
Args:
|
||||
segment (torch.Tensor): the segment label
|
||||
width (int): the width of the image. Defaults to 640
|
||||
height (int): The height of the image. Defaults to 640
|
||||
|
||||
Returns:
|
||||
(np.ndarray): the minimum and maximum x and y values of the segment.
|
||||
"""
|
||||
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
|
||||
x, y = segment.T # segment xy
|
||||
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
|
||||
x, y, = x[inside], y[inside]
|
||||
return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros(
|
||||
4, dtype=segment.dtype) # xyxy
|
||||
|
||||
|
||||
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
|
||||
"""
|
||||
Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
|
||||
(img1_shape) to the shape of a different image (img0_shape).
|
||||
|
||||
Args:
|
||||
img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
|
||||
boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
|
||||
img0_shape (tuple): the shape of the target image, in the format of (height, width).
|
||||
ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
|
||||
calculated based on the size difference between the two images.
|
||||
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
|
||||
rescaling.
|
||||
|
||||
Returns:
|
||||
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
|
||||
"""
|
||||
if ratio_pad is None: # calculate from img0_shape
|
||||
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
||||
pad = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), round(
|
||||
(img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1) # wh padding
|
||||
else:
|
||||
gain = ratio_pad[0][0]
|
||||
pad = ratio_pad[1]
|
||||
|
||||
if padding:
|
||||
boxes[..., [0, 2]] -= pad[0] # x padding
|
||||
boxes[..., [1, 3]] -= pad[1] # y padding
|
||||
boxes[..., :4] /= gain
|
||||
clip_boxes(boxes, img0_shape)
|
||||
return boxes
|
||||
|
||||
|
||||
def make_divisible(x, divisor):
|
||||
"""
|
||||
Returns the nearest number that is divisible by the given divisor.
|
||||
|
||||
Args:
|
||||
x (int): The number to make divisible.
|
||||
divisor (int | torch.Tensor): The divisor.
|
||||
|
||||
Returns:
|
||||
(int): The nearest number divisible by the divisor.
|
||||
"""
|
||||
if isinstance(divisor, torch.Tensor):
|
||||
divisor = int(divisor.max()) # to int
|
||||
return math.ceil(x / divisor) * divisor
|
||||
|
||||
|
||||
def non_max_suppression(
|
||||
prediction,
|
||||
conf_thres=0.25,
|
||||
iou_thres=0.45,
|
||||
classes=None,
|
||||
agnostic=False,
|
||||
multi_label=False,
|
||||
labels=(),
|
||||
max_det=300,
|
||||
nc=0, # number of classes (optional)
|
||||
max_time_img=0.05,
|
||||
max_nms=30000,
|
||||
max_wh=7680,
|
||||
):
|
||||
"""
|
||||
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
|
||||
|
||||
Arguments:
|
||||
prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
|
||||
containing the predicted boxes, classes, and masks. The tensor should be in the format
|
||||
output by a model, such as YOLO.
|
||||
conf_thres (float): The confidence threshold below which boxes will be filtered out.
|
||||
Valid values are between 0.0 and 1.0.
|
||||
iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
|
||||
Valid values are between 0.0 and 1.0.
|
||||
classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
|
||||
agnostic (bool): If True, the model is agnostic to the number of classes, and all
|
||||
classes will be considered as one.
|
||||
multi_label (bool): If True, each box may have multiple labels.
|
||||
labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
|
||||
list contains the apriori labels for a given image. The list should be in the format
|
||||
output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
|
||||
max_det (int): The maximum number of boxes to keep after NMS.
|
||||
nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
|
||||
max_time_img (float): The maximum time (seconds) for processing one image.
|
||||
max_nms (int): The maximum number of boxes into torchvision.ops.nms().
|
||||
max_wh (int): The maximum box width and height in pixels
|
||||
|
||||
Returns:
|
||||
(List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
|
||||
shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
|
||||
(x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
|
||||
"""
|
||||
|
||||
# Checks
|
||||
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
|
||||
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
|
||||
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)
|
||||
prediction = prediction[0] # select only inference output
|
||||
|
||||
device = prediction.device
|
||||
mps = 'mps' in device.type # Apple MPS
|
||||
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
|
||||
prediction = prediction.cpu()
|
||||
bs = prediction.shape[0] # batch size
|
||||
nc = nc or (prediction.shape[1] - 4) # number of classes
|
||||
nm = prediction.shape[1] - nc - 4
|
||||
mi = 4 + nc # mask start index
|
||||
xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
|
||||
|
||||
# Settings
|
||||
# min_wh = 2 # (pixels) minimum box width and height
|
||||
time_limit = 0.5 + max_time_img * bs # seconds to quit after
|
||||
redundant = True # require redundant detections
|
||||
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
||||
merge = False # use merge-NMS
|
||||
|
||||
prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
|
||||
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
|
||||
|
||||
t = time.time()
|
||||
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
|
||||
for xi, x in enumerate(prediction): # image index, image inference
|
||||
# Apply constraints
|
||||
# x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
||||
x = x[xc[xi]] # confidence
|
||||
|
||||
# Cat apriori labels if autolabelling
|
||||
if labels and len(labels[xi]):
|
||||
lb = labels[xi]
|
||||
v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
|
||||
v[:, :4] = lb[:, 1:5] # box
|
||||
v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
|
||||
x = torch.cat((x, v), 0)
|
||||
|
||||
# If none remain process next image
|
||||
if not x.shape[0]:
|
||||
continue
|
||||
|
||||
# Detections matrix nx6 (xyxy, conf, cls)
|
||||
box, cls, mask = x.split((4, nc, nm), 1)
|
||||
|
||||
if multi_label:
|
||||
i, j = torch.where(cls > conf_thres)
|
||||
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
|
||||
else: # best class only
|
||||
conf, j = cls.max(1, keepdim=True)
|
||||
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
|
||||
|
||||
# Filter by class
|
||||
if classes is not None:
|
||||
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
|
||||
|
||||
# Apply finite constraint
|
||||
# if not torch.isfinite(x).all():
|
||||
# x = x[torch.isfinite(x).all(1)]
|
||||
|
||||
# Check shape
|
||||
n = x.shape[0] # number of boxes
|
||||
if not n: # no boxes
|
||||
continue
|
||||
if n > max_nms: # excess boxes
|
||||
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
|
||||
|
||||
# Batched NMS
|
||||
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
||||
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
||||
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
||||
i = i[:max_det] # limit detections
|
||||
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
||||
# Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
|
||||
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
|
||||
weights = iou * scores[None] # box weights
|
||||
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
|
||||
if redundant:
|
||||
i = i[iou.sum(1) > 1] # require redundancy
|
||||
|
||||
output[xi] = x[i]
|
||||
if mps:
|
||||
output[xi] = output[xi].to(device)
|
||||
if (time.time() - t) > time_limit:
|
||||
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
|
||||
break # time limit exceeded
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def clip_boxes(boxes, shape):
|
||||
"""
|
||||
It takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the
|
||||
shape
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor): the bounding boxes to clip
|
||||
shape (tuple): the shape of the image
|
||||
"""
|
||||
if isinstance(boxes, torch.Tensor): # faster individually
|
||||
boxes[..., 0].clamp_(0, shape[1]) # x1
|
||||
boxes[..., 1].clamp_(0, shape[0]) # y1
|
||||
boxes[..., 2].clamp_(0, shape[1]) # x2
|
||||
boxes[..., 3].clamp_(0, shape[0]) # y2
|
||||
else: # np.array (faster grouped)
|
||||
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
|
||||
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
|
||||
|
||||
|
||||
def clip_coords(coords, shape):
|
||||
"""
|
||||
Clip line coordinates to the image boundaries.
|
||||
|
||||
Args:
|
||||
coords (torch.Tensor | numpy.ndarray): A list of line coordinates.
|
||||
shape (tuple): A tuple of integers representing the size of the image in the format (height, width).
|
||||
|
||||
Returns:
|
||||
(None): The function modifies the input `coordinates` in place, by clipping each coordinate to the image boundaries.
|
||||
"""
|
||||
if isinstance(coords, torch.Tensor): # faster individually
|
||||
coords[..., 0].clamp_(0, shape[1]) # x
|
||||
coords[..., 1].clamp_(0, shape[0]) # y
|
||||
else: # np.array (faster grouped)
|
||||
coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x
|
||||
coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y
|
||||
|
||||
|
||||
def scale_image(masks, im0_shape, ratio_pad=None):
|
||||
"""
|
||||
Takes a mask, and resizes it to the original image size
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): resized and padded masks/images, [h, w, num]/[h, w, 3].
|
||||
im0_shape (tuple): the original image shape
|
||||
ratio_pad (tuple): the ratio of the padding to the original image.
|
||||
|
||||
Returns:
|
||||
masks (torch.Tensor): The masks that are being returned.
|
||||
"""
|
||||
# Rescale coordinates (xyxy) from im1_shape to im0_shape
|
||||
im1_shape = masks.shape
|
||||
if im1_shape[:2] == im0_shape[:2]:
|
||||
return masks
|
||||
if ratio_pad is None: # calculate from im0_shape
|
||||
gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new
|
||||
pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding
|
||||
else:
|
||||
gain = ratio_pad[0][0]
|
||||
pad = ratio_pad[1]
|
||||
top, left = int(pad[1]), int(pad[0]) # y, x
|
||||
bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])
|
||||
|
||||
if len(masks.shape) < 2:
|
||||
raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
|
||||
masks = masks[top:bottom, left:right]
|
||||
# masks = masks.permute(2, 0, 1).contiguous()
|
||||
# masks = F.interpolate(masks[None], im0_shape[:2], mode='bilinear', align_corners=False)[0]
|
||||
# masks = masks.permute(1, 2, 0).contiguous()
|
||||
masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
|
||||
if len(masks.shape) == 2:
|
||||
masks = masks[:, :, None]
|
||||
|
||||
return masks
|
||||
|
||||
|
||||
def xyxy2xywh(x):
|
||||
"""
|
||||
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
|
||||
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
|
||||
y[..., 2] = x[..., 2] - x[..., 0] # width
|
||||
y[..., 3] = x[..., 3] - x[..., 1] # height
|
||||
return y
|
||||
|
||||
|
||||
def xywh2xyxy(x):
|
||||
"""
|
||||
Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
|
||||
top-left corner and (x2, y2) is the bottom-right corner.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
|
||||
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
|
||||
y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
|
||||
y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
|
||||
return y
|
||||
|
||||
|
||||
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
|
||||
"""
|
||||
Convert normalized bounding box coordinates to pixel coordinates.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): The bounding box coordinates.
|
||||
w (int): Width of the image. Defaults to 640
|
||||
h (int): Height of the image. Defaults to 640
|
||||
padw (int): Padding width. Defaults to 0
|
||||
padh (int): Padding height. Defaults to 0
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
|
||||
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
|
||||
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
|
||||
y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
|
||||
y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
|
||||
return y
|
||||
|
||||
|
||||
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
||||
"""
|
||||
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format.
|
||||
x, y, width and height are normalized to image dimensions
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
|
||||
w (int): The width of the image. Defaults to 640
|
||||
h (int): The height of the image. Defaults to 640
|
||||
clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False
|
||||
eps (float): The minimum value of the box's width and height. Defaults to 0.0
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
|
||||
"""
|
||||
if clip:
|
||||
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
|
||||
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
|
||||
y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
|
||||
y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
|
||||
return y
|
||||
|
||||
|
||||
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
|
||||
"""
|
||||
Convert normalized coordinates to pixel coordinates of shape (n,2)
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): The input tensor of normalized bounding box coordinates
|
||||
w (int): The width of the image. Defaults to 640
|
||||
h (int): The height of the image. Defaults to 640
|
||||
padw (int): The width of the padding. Defaults to 0
|
||||
padh (int): The height of the padding. Defaults to 0
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The x and y coordinates of the top left corner of the bounding box
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[..., 0] = w * x[..., 0] + padw # top left x
|
||||
y[..., 1] = h * x[..., 1] + padh # top left y
|
||||
return y
|
||||
|
||||
|
||||
def xywh2ltwh(x):
|
||||
"""
|
||||
Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
||||
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
||||
return y
|
||||
|
||||
|
||||
def xyxy2ltwh(x):
|
||||
"""
|
||||
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[:, 2] = x[:, 2] - x[:, 0] # width
|
||||
y[:, 3] = x[:, 3] - x[:, 1] # height
|
||||
return y
|
||||
|
||||
|
||||
def ltwh2xywh(x):
|
||||
"""
|
||||
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): the input tensor
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[:, 0] = x[:, 0] + x[:, 2] / 2 # center x
|
||||
y[:, 1] = x[:, 1] + x[:, 3] / 2 # center y
|
||||
return y
|
||||
|
||||
|
||||
def ltwh2xyxy(x):
|
||||
"""
|
||||
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): the input image
|
||||
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
|
||||
"""
|
||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||
y[:, 2] = x[:, 2] + x[:, 0] # width
|
||||
y[:, 3] = x[:, 3] + x[:, 1] # height
|
||||
return y
|
||||
|
||||
|
||||
def segments2boxes(segments):
|
||||
"""
|
||||
It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
|
||||
|
||||
Args:
|
||||
segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
|
||||
|
||||
Returns:
|
||||
(np.ndarray): the xywh coordinates of the bounding boxes.
|
||||
"""
|
||||
boxes = []
|
||||
for s in segments:
|
||||
x, y = s.T # segment xy
|
||||
boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
|
||||
return xyxy2xywh(np.array(boxes)) # cls, xywh
|
||||
|
||||
|
||||
def resample_segments(segments, n=1000):
|
||||
"""
|
||||
Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each.
|
||||
|
||||
Args:
|
||||
segments (list): a list of (n,2) arrays, where n is the number of points in the segment.
|
||||
n (int): number of points to resample the segment to. Defaults to 1000
|
||||
|
||||
Returns:
|
||||
segments (list): the resampled segments.
|
||||
"""
|
||||
for i, s in enumerate(segments):
|
||||
s = np.concatenate((s, s[0:1, :]), axis=0)
|
||||
x = np.linspace(0, len(s) - 1, n)
|
||||
xp = np.arange(len(s))
|
||||
segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)],
|
||||
dtype=np.float32).reshape(2, -1).T # segment xy
|
||||
return segments
|
||||
|
||||
|
||||
def crop_mask(masks, boxes):
|
||||
"""
|
||||
It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): [n, h, w] tensor of masks
|
||||
boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The masks are being cropped to the bounding box.
|
||||
"""
|
||||
n, h, w = masks.shape
|
||||
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
|
||||
r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)
|
||||
c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)
|
||||
|
||||
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
|
||||
|
||||
|
||||
def process_mask_upsample(protos, masks_in, bboxes, shape):
|
||||
"""
|
||||
It takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
|
||||
quality but is slower.
|
||||
|
||||
Args:
|
||||
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
|
||||
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
|
||||
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
|
||||
shape (tuple): the size of the input image (h,w)
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The upsampled masks.
|
||||
"""
|
||||
c, mh, mw = protos.shape # CHW
|
||||
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
|
||||
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
|
||||
masks = crop_mask(masks, bboxes) # CHW
|
||||
return masks.gt_(0.5)
|
||||
|
||||
|
||||
def process_mask(protos, masks_in, bboxes, shape, upsample=False):
|
||||
"""
|
||||
Apply masks to bounding boxes using the output of the mask head.
|
||||
|
||||
Args:
|
||||
protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w].
|
||||
masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
|
||||
bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS.
|
||||
shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
|
||||
upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
|
||||
are the height and width of the input image. The mask is applied to the bounding boxes.
|
||||
"""
|
||||
|
||||
c, mh, mw = protos.shape # CHW
|
||||
ih, iw = shape
|
||||
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW
|
||||
|
||||
downsampled_bboxes = bboxes.clone()
|
||||
downsampled_bboxes[:, 0] *= mw / iw
|
||||
downsampled_bboxes[:, 2] *= mw / iw
|
||||
downsampled_bboxes[:, 3] *= mh / ih
|
||||
downsampled_bboxes[:, 1] *= mh / ih
|
||||
|
||||
masks = crop_mask(masks, downsampled_bboxes) # CHW
|
||||
if upsample:
|
||||
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
|
||||
return masks.gt_(0.5)
|
||||
|
||||
|
||||
def process_mask_native(protos, masks_in, bboxes, shape):
|
||||
"""
|
||||
It takes the output of the mask head, and crops it after upsampling to the bounding boxes.
|
||||
|
||||
Args:
|
||||
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
|
||||
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
|
||||
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
|
||||
shape (tuple): the size of the input image (h,w)
|
||||
|
||||
Returns:
|
||||
masks (torch.Tensor): The returned masks with dimensions [h, w, n]
|
||||
"""
|
||||
c, mh, mw = protos.shape # CHW
|
||||
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
|
||||
masks = scale_masks(masks[None], shape)[0] # CHW
|
||||
masks = crop_mask(masks, bboxes) # CHW
|
||||
return masks.gt_(0.5)
|
||||
|
||||
|
||||
def scale_masks(masks, shape, padding=True):
|
||||
"""
|
||||
Rescale segment masks to shape.
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): (N, C, H, W).
|
||||
shape (tuple): Height and width.
|
||||
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
|
||||
rescaling.
|
||||
"""
|
||||
mh, mw = masks.shape[2:]
|
||||
gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
|
||||
pad = [mw - shape[1] * gain, mh - shape[0] * gain] # wh padding
|
||||
if padding:
|
||||
pad[0] /= 2
|
||||
pad[1] /= 2
|
||||
top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0) # y, x
|
||||
bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
|
||||
masks = masks[..., top:bottom, left:right]
|
||||
|
||||
masks = F.interpolate(masks, shape, mode='bilinear', align_corners=False) # NCHW
|
||||
return masks
|
||||
|
||||
|
||||
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
|
||||
"""
|
||||
Rescale segment coordinates (xyxy) from img1_shape to img0_shape
|
||||
|
||||
Args:
|
||||
img1_shape (tuple): The shape of the image that the coords are from.
|
||||
coords (torch.Tensor): the coords to be scaled
|
||||
img0_shape (tuple): the shape of the image that the segmentation is being applied to
|
||||
ratio_pad (tuple): the ratio of the image size to the padded image size.
|
||||
normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False
|
||||
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
|
||||
rescaling.
|
||||
|
||||
Returns:
|
||||
coords (torch.Tensor): the segmented image.
|
||||
"""
|
||||
if ratio_pad is None: # calculate from img0_shape
|
||||
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
||||
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
|
||||
else:
|
||||
gain = ratio_pad[0][0]
|
||||
pad = ratio_pad[1]
|
||||
|
||||
if padding:
|
||||
coords[..., 0] -= pad[0] # x padding
|
||||
coords[..., 1] -= pad[1] # y padding
|
||||
coords[..., 0] /= gain
|
||||
coords[..., 1] /= gain
|
||||
clip_coords(coords, img0_shape)
|
||||
if normalize:
|
||||
coords[..., 0] /= img0_shape[1] # width
|
||||
coords[..., 1] /= img0_shape[0] # height
|
||||
return coords
|
||||
|
||||
|
||||
def masks2segments(masks, strategy='largest'):
|
||||
"""
|
||||
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
|
||||
strategy (str): 'concat' or 'largest'. Defaults to largest
|
||||
|
||||
Returns:
|
||||
segments (List): list of segment masks
|
||||
"""
|
||||
segments = []
|
||||
for x in masks.int().cpu().numpy().astype('uint8'):
|
||||
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
|
||||
if c:
|
||||
if strategy == 'concat': # concatenate all segments
|
||||
c = np.concatenate([x.reshape(-1, 2) for x in c])
|
||||
elif strategy == 'largest': # select largest segment
|
||||
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
|
||||
else:
|
||||
c = np.zeros((0, 2)) # no segments found
|
||||
segments.append(c.astype('float32'))
|
||||
return segments
|
||||
|
||||
|
||||
def clean_str(s):
|
||||
"""
|
||||
Cleans a string by replacing special characters with underscore _
|
||||
|
||||
Args:
|
||||
s (str): a string needing special characters replaced
|
||||
|
||||
Returns:
|
||||
(str): a string with special characters replaced by an underscore _
|
||||
"""
|
||||
return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
|
@ -1,45 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Monkey patches to update/extend functionality of existing functions
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
|
||||
_imshow = cv2.imshow # copy to avoid recursion errors
|
||||
|
||||
|
||||
def imread(filename, flags=cv2.IMREAD_COLOR):
|
||||
return cv2.imdecode(np.fromfile(filename, np.uint8), flags)
|
||||
|
||||
|
||||
def imwrite(filename, img):
|
||||
try:
|
||||
cv2.imencode(Path(filename).suffix, img)[1].tofile(filename)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def imshow(path, im):
|
||||
_imshow(path.encode('unicode_escape').decode(), im)
|
||||
|
||||
|
||||
# PyTorch functions ----------------------------------------------------------------------------------------------------
|
||||
_torch_save = torch.save # copy to avoid recursion errors
|
||||
|
||||
|
||||
def torch_save(*args, **kwargs):
|
||||
"""Use dill (if exists) to serialize the lambda functions where pickle does not do this."""
|
||||
try:
|
||||
import dill as pickle
|
||||
except ImportError:
|
||||
import pickle
|
||||
|
||||
if 'pickle_module' not in kwargs:
|
||||
kwargs['pickle_module'] = pickle
|
||||
return _torch_save(*args, **kwargs)
|
@ -1,527 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from PIL import __version__ as pil_version
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, TryExcept, plt_settings, threaded
|
||||
|
||||
from .checks import check_font, check_version, is_ascii
|
||||
from .files import increment_path
|
||||
from .ops import clip_boxes, scale_image, xywh2xyxy, xyxy2xywh
|
||||
|
||||
|
||||
class Colors:
|
||||
"""Ultralytics color palette https://ultralytics.com/."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
|
||||
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
|
||||
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
|
||||
self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
|
||||
self.n = len(self.palette)
|
||||
self.pose_palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], [230, 230, 0], [255, 153, 255],
|
||||
[153, 204, 255], [255, 102, 255], [255, 51, 255], [102, 178, 255], [51, 153, 255],
|
||||
[255, 153, 153], [255, 102, 102], [255, 51, 51], [153, 255, 153], [102, 255, 102],
|
||||
[51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0], [255, 255, 255]],
|
||||
dtype=np.uint8)
|
||||
|
||||
def __call__(self, i, bgr=False):
|
||||
"""Converts hex color codes to rgb values."""
|
||||
c = self.palette[int(i) % self.n]
|
||||
return (c[2], c[1], c[0]) if bgr else c
|
||||
|
||||
@staticmethod
|
||||
def hex2rgb(h): # rgb order (PIL)
|
||||
return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
|
||||
|
||||
|
||||
colors = Colors() # create instance for 'from utils.plots import colors'
|
||||
|
||||
|
||||
class Annotator:
|
||||
"""YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations."""
|
||||
|
||||
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
||||
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
|
||||
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
|
||||
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
|
||||
self.pil = pil or non_ascii
|
||||
if self.pil: # use PIL
|
||||
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
||||
self.draw = ImageDraw.Draw(self.im)
|
||||
try:
|
||||
font = check_font('Arial.Unicode.ttf' if non_ascii else font)
|
||||
size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
|
||||
self.font = ImageFont.truetype(str(font), size)
|
||||
except Exception:
|
||||
self.font = ImageFont.load_default()
|
||||
# Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string)
|
||||
if check_version(pil_version, '9.2.0'):
|
||||
self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
|
||||
else: # use cv2
|
||||
self.im = im
|
||||
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
|
||||
# Pose
|
||||
self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9],
|
||||
[8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
|
||||
|
||||
self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
|
||||
self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
|
||||
|
||||
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
|
||||
"""Add one xyxy box to image with label."""
|
||||
if isinstance(box, torch.Tensor):
|
||||
box = box.tolist()
|
||||
if self.pil or not is_ascii(label):
|
||||
self.draw.rectangle(box, width=self.lw, outline=color) # box
|
||||
if label:
|
||||
w, h = self.font.getsize(label) # text width, height
|
||||
outside = box[1] - h >= 0 # label fits outside box
|
||||
self.draw.rectangle(
|
||||
(box[0], box[1] - h if outside else box[1], box[0] + w + 1,
|
||||
box[1] + 1 if outside else box[1] + h + 1),
|
||||
fill=color,
|
||||
)
|
||||
# self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
|
||||
self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
|
||||
else: # cv2
|
||||
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
|
||||
cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
|
||||
if label:
|
||||
tf = max(self.lw - 1, 1) # font thickness
|
||||
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
|
||||
outside = p1[1] - h >= 3
|
||||
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
||||
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
|
||||
cv2.putText(self.im,
|
||||
label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
|
||||
0,
|
||||
self.lw / 3,
|
||||
txt_color,
|
||||
thickness=tf,
|
||||
lineType=cv2.LINE_AA)
|
||||
|
||||
def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
|
||||
"""Plot masks at once.
|
||||
Args:
|
||||
masks (tensor): predicted masks on cuda, shape: [n, h, w]
|
||||
colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
|
||||
im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
|
||||
alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
|
||||
"""
|
||||
if self.pil:
|
||||
# Convert to numpy first
|
||||
self.im = np.asarray(self.im).copy()
|
||||
if len(masks) == 0:
|
||||
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
|
||||
if im_gpu.device != masks.device:
|
||||
im_gpu = im_gpu.to(masks.device)
|
||||
colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
|
||||
colors = colors[:, None, None] # shape(n,1,1,3)
|
||||
masks = masks.unsqueeze(3) # shape(n,h,w,1)
|
||||
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
|
||||
|
||||
inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
|
||||
mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
|
||||
|
||||
im_gpu = im_gpu.flip(dims=[0]) # flip channel
|
||||
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
|
||||
im_gpu = im_gpu * inv_alph_masks[-1] + mcs
|
||||
im_mask = (im_gpu * 255)
|
||||
im_mask_np = im_mask.byte().cpu().numpy()
|
||||
self.im[:] = im_mask_np if retina_masks else scale_image(im_mask_np, self.im.shape)
|
||||
if self.pil:
|
||||
# Convert im back to PIL and update draw
|
||||
self.fromarray(self.im)
|
||||
|
||||
def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True):
|
||||
"""Plot keypoints on the image.
|
||||
|
||||
Args:
|
||||
kpts (tensor): Predicted keypoints with shape [17, 3]. Each keypoint has (x, y, confidence).
|
||||
shape (tuple): Image shape as a tuple (h, w), where h is the height and w is the width.
|
||||
radius (int, optional): Radius of the drawn keypoints. Default is 5.
|
||||
kpt_line (bool, optional): If True, the function will draw lines connecting keypoints
|
||||
for human pose. Default is True.
|
||||
|
||||
Note: `kpt_line=True` currently only supports human pose plotting.
|
||||
"""
|
||||
if self.pil:
|
||||
# Convert to numpy first
|
||||
self.im = np.asarray(self.im).copy()
|
||||
nkpt, ndim = kpts.shape
|
||||
is_pose = nkpt == 17 and ndim == 3
|
||||
kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting
|
||||
for i, k in enumerate(kpts):
|
||||
color_k = [int(x) for x in self.kpt_color[i]] if is_pose else colors(i)
|
||||
x_coord, y_coord = k[0], k[1]
|
||||
if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
|
||||
if len(k) == 3:
|
||||
conf = k[2]
|
||||
if conf < 0.5:
|
||||
continue
|
||||
cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)
|
||||
|
||||
if kpt_line:
|
||||
ndim = kpts.shape[-1]
|
||||
for i, sk in enumerate(self.skeleton):
|
||||
pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))
|
||||
pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))
|
||||
if ndim == 3:
|
||||
conf1 = kpts[(sk[0] - 1), 2]
|
||||
conf2 = kpts[(sk[1] - 1), 2]
|
||||
if conf1 < 0.5 or conf2 < 0.5:
|
||||
continue
|
||||
if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:
|
||||
continue
|
||||
if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:
|
||||
continue
|
||||
cv2.line(self.im, pos1, pos2, [int(x) for x in self.limb_color[i]], thickness=2, lineType=cv2.LINE_AA)
|
||||
if self.pil:
|
||||
# Convert im back to PIL and update draw
|
||||
self.fromarray(self.im)
|
||||
|
||||
def rectangle(self, xy, fill=None, outline=None, width=1):
|
||||
"""Add rectangle to image (PIL-only)."""
|
||||
self.draw.rectangle(xy, fill, outline, width)
|
||||
|
||||
def text(self, xy, text, txt_color=(255, 255, 255), anchor='top', box_style=False):
|
||||
"""Adds text to an image using PIL or cv2."""
|
||||
if anchor == 'bottom': # start y from font bottom
|
||||
w, h = self.font.getsize(text) # text width, height
|
||||
xy[1] += 1 - h
|
||||
if self.pil:
|
||||
if box_style:
|
||||
w, h = self.font.getsize(text)
|
||||
self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
|
||||
# Using `txt_color` for background and draw fg with white color
|
||||
txt_color = (255, 255, 255)
|
||||
if '\n' in text:
|
||||
lines = text.split('\n')
|
||||
_, h = self.font.getsize(text)
|
||||
for line in lines:
|
||||
self.draw.text(xy, line, fill=txt_color, font=self.font)
|
||||
xy[1] += h
|
||||
else:
|
||||
self.draw.text(xy, text, fill=txt_color, font=self.font)
|
||||
else:
|
||||
if box_style:
|
||||
tf = max(self.lw - 1, 1) # font thickness
|
||||
w, h = cv2.getTextSize(text, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
|
||||
outside = xy[1] - h >= 3
|
||||
p2 = xy[0] + w, xy[1] - h - 3 if outside else xy[1] + h + 3
|
||||
cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA) # filled
|
||||
# Using `txt_color` for background and draw fg with white color
|
||||
txt_color = (255, 255, 255)
|
||||
tf = max(self.lw - 1, 1) # font thickness
|
||||
cv2.putText(self.im, text, xy, 0, self.lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA)
|
||||
|
||||
def fromarray(self, im):
|
||||
"""Update self.im from a numpy array."""
|
||||
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
||||
self.draw = ImageDraw.Draw(self.im)
|
||||
|
||||
def result(self):
|
||||
"""Return annotated image as array."""
|
||||
return np.asarray(self.im)
|
||||
|
||||
|
||||
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
||||
@plt_settings()
|
||||
def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
|
||||
"""Save and plot image with no axis or spines."""
|
||||
import pandas as pd
|
||||
import seaborn as sn
|
||||
|
||||
# Filter matplotlib>=3.7.2 warning
|
||||
warnings.filterwarnings('ignore', category=UserWarning, message='The figure layout has changed to tight')
|
||||
|
||||
# Plot dataset labels
|
||||
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
||||
b = boxes.transpose() # classes, boxes
|
||||
nc = int(cls.max() + 1) # number of classes
|
||||
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
|
||||
|
||||
# Seaborn correlogram
|
||||
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
||||
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
|
||||
plt.close()
|
||||
|
||||
# Matplotlib labels
|
||||
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
||||
y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
||||
with contextlib.suppress(Exception): # color histogram bars by class
|
||||
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
|
||||
ax[0].set_ylabel('instances')
|
||||
if 0 < len(names) < 30:
|
||||
ax[0].set_xticks(range(len(names)))
|
||||
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
|
||||
else:
|
||||
ax[0].set_xlabel('classes')
|
||||
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
|
||||
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
|
||||
|
||||
# Rectangles
|
||||
boxes[:, 0:2] = 0.5 # center
|
||||
boxes = xywh2xyxy(boxes) * 1000
|
||||
img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
|
||||
for cls, box in zip(cls[:500], boxes[:500]):
|
||||
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
|
||||
ax[1].imshow(img)
|
||||
ax[1].axis('off')
|
||||
|
||||
for a in [0, 1, 2, 3]:
|
||||
for s in ['top', 'right', 'left', 'bottom']:
|
||||
ax[a].spines[s].set_visible(False)
|
||||
|
||||
fname = save_dir / 'labels.jpg'
|
||||
plt.savefig(fname, dpi=200)
|
||||
plt.close()
|
||||
if on_plot:
|
||||
on_plot(fname)
|
||||
|
||||
|
||||
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
|
||||
"""Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop."""
|
||||
b = xyxy2xywh(xyxy.view(-1, 4)) # boxes
|
||||
if square:
|
||||
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
|
||||
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
|
||||
xyxy = xywh2xyxy(b).long()
|
||||
clip_boxes(xyxy, im.shape)
|
||||
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
|
||||
if save:
|
||||
file.parent.mkdir(parents=True, exist_ok=True) # make directory
|
||||
f = str(increment_path(file).with_suffix('.jpg'))
|
||||
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
|
||||
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
|
||||
return crop
|
||||
|
||||
|
||||
@threaded
|
||||
def plot_images(images,
|
||||
batch_idx,
|
||||
cls,
|
||||
bboxes=np.zeros(0, dtype=np.float32),
|
||||
masks=np.zeros(0, dtype=np.uint8),
|
||||
kpts=np.zeros((0, 51), dtype=np.float32),
|
||||
paths=None,
|
||||
fname='images.jpg',
|
||||
names=None,
|
||||
on_plot=None):
|
||||
"""Plot image grid with labels."""
|
||||
if isinstance(images, torch.Tensor):
|
||||
images = images.cpu().float().numpy()
|
||||
if isinstance(cls, torch.Tensor):
|
||||
cls = cls.cpu().numpy()
|
||||
if isinstance(bboxes, torch.Tensor):
|
||||
bboxes = bboxes.cpu().numpy()
|
||||
if isinstance(masks, torch.Tensor):
|
||||
masks = masks.cpu().numpy().astype(int)
|
||||
if isinstance(kpts, torch.Tensor):
|
||||
kpts = kpts.cpu().numpy()
|
||||
if isinstance(batch_idx, torch.Tensor):
|
||||
batch_idx = batch_idx.cpu().numpy()
|
||||
|
||||
max_size = 1920 # max image size
|
||||
max_subplots = 16 # max image subplots, i.e. 4x4
|
||||
bs, _, h, w = images.shape # batch size, _, height, width
|
||||
bs = min(bs, max_subplots) # limit plot images
|
||||
ns = np.ceil(bs ** 0.5) # number of subplots (square)
|
||||
if np.max(images[0]) <= 1:
|
||||
images *= 255 # de-normalise (optional)
|
||||
|
||||
# Build Image
|
||||
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
|
||||
for i, im in enumerate(images):
|
||||
if i == max_subplots: # if last batch has fewer images than we expect
|
||||
break
|
||||
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
||||
im = im.transpose(1, 2, 0)
|
||||
mosaic[y:y + h, x:x + w, :] = im
|
||||
|
||||
# Resize (optional)
|
||||
scale = max_size / ns / max(h, w)
|
||||
if scale < 1:
|
||||
h = math.ceil(scale * h)
|
||||
w = math.ceil(scale * w)
|
||||
mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
|
||||
|
||||
# Annotate
|
||||
fs = int((h + w) * ns * 0.01) # font size
|
||||
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
|
||||
for i in range(i + 1):
|
||||
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
||||
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
|
||||
if paths:
|
||||
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
|
||||
if len(cls) > 0:
|
||||
idx = batch_idx == i
|
||||
classes = cls[idx].astype('int')
|
||||
|
||||
if len(bboxes):
|
||||
boxes = xywh2xyxy(bboxes[idx, :4]).T
|
||||
labels = bboxes.shape[1] == 4 # labels if no conf column
|
||||
conf = None if labels else bboxes[idx, 4] # check for confidence presence (label vs pred)
|
||||
|
||||
if boxes.shape[1]:
|
||||
if boxes.max() <= 1.01: # if normalized with tolerance 0.01
|
||||
boxes[[0, 2]] *= w # scale to pixels
|
||||
boxes[[1, 3]] *= h
|
||||
elif scale < 1: # absolute coords need scale if image scales
|
||||
boxes *= scale
|
||||
boxes[[0, 2]] += x
|
||||
boxes[[1, 3]] += y
|
||||
for j, box in enumerate(boxes.T.tolist()):
|
||||
c = classes[j]
|
||||
color = colors(c)
|
||||
c = names.get(c, c) if names else c
|
||||
if labels or conf[j] > 0.25: # 0.25 conf thresh
|
||||
label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
|
||||
annotator.box_label(box, label, color=color)
|
||||
elif len(classes):
|
||||
for c in classes:
|
||||
color = colors(c)
|
||||
c = names.get(c, c) if names else c
|
||||
annotator.text((x, y), f'{c}', txt_color=color, box_style=True)
|
||||
|
||||
# Plot keypoints
|
||||
if len(kpts):
|
||||
kpts_ = kpts[idx].copy()
|
||||
if len(kpts_):
|
||||
if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01
|
||||
kpts_[..., 0] *= w # scale to pixels
|
||||
kpts_[..., 1] *= h
|
||||
elif scale < 1: # absolute coords need scale if image scales
|
||||
kpts_ *= scale
|
||||
kpts_[..., 0] += x
|
||||
kpts_[..., 1] += y
|
||||
for j in range(len(kpts_)):
|
||||
if labels or conf[j] > 0.25: # 0.25 conf thresh
|
||||
annotator.kpts(kpts_[j])
|
||||
|
||||
# Plot masks
|
||||
if len(masks):
|
||||
if idx.shape[0] == masks.shape[0]: # overlap_masks=False
|
||||
image_masks = masks[idx]
|
||||
else: # overlap_masks=True
|
||||
image_masks = masks[[i]] # (1, 640, 640)
|
||||
nl = idx.sum()
|
||||
index = np.arange(nl).reshape((nl, 1, 1)) + 1
|
||||
image_masks = np.repeat(image_masks, nl, axis=0)
|
||||
image_masks = np.where(image_masks == index, 1.0, 0.0)
|
||||
|
||||
im = np.asarray(annotator.im).copy()
|
||||
for j, box in enumerate(boxes.T.tolist()):
|
||||
if labels or conf[j] > 0.25: # 0.25 conf thresh
|
||||
color = colors(classes[j])
|
||||
mh, mw = image_masks[j].shape
|
||||
if mh != h or mw != w:
|
||||
mask = image_masks[j].astype(np.uint8)
|
||||
mask = cv2.resize(mask, (w, h))
|
||||
mask = mask.astype(bool)
|
||||
else:
|
||||
mask = image_masks[j].astype(bool)
|
||||
with contextlib.suppress(Exception):
|
||||
im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
|
||||
annotator.fromarray(im)
|
||||
annotator.im.save(fname) # save
|
||||
if on_plot:
|
||||
on_plot(fname)
|
||||
|
||||
|
||||
@plt_settings()
|
||||
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
|
||||
"""Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')."""
|
||||
import pandas as pd
|
||||
save_dir = Path(file).parent if file else Path(dir)
|
||||
if classify:
|
||||
fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
|
||||
index = [1, 4, 2, 3]
|
||||
elif segment:
|
||||
fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
|
||||
index = [1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]
|
||||
elif pose:
|
||||
fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
|
||||
index = [1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 18, 8, 9, 12, 13]
|
||||
else:
|
||||
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
|
||||
index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
|
||||
ax = ax.ravel()
|
||||
files = list(save_dir.glob('results*.csv'))
|
||||
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
|
||||
for f in files:
|
||||
try:
|
||||
data = pd.read_csv(f)
|
||||
s = [x.strip() for x in data.columns]
|
||||
x = data.values[:, 0]
|
||||
for i, j in enumerate(index):
|
||||
y = data.values[:, j].astype('float')
|
||||
# y[y == 0] = np.nan # don't show zero values
|
||||
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8) # actual results
|
||||
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2) # smoothing line
|
||||
ax[i].set_title(s[j], fontsize=12)
|
||||
# if j in [8, 9, 10]: # share train and val loss y axes
|
||||
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
|
||||
ax[1].legend()
|
||||
fname = save_dir / 'results.png'
|
||||
fig.savefig(fname, dpi=200)
|
||||
plt.close()
|
||||
if on_plot:
|
||||
on_plot(fname)
|
||||
|
||||
|
||||
def output_to_target(output, max_det=300):
|
||||
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
|
||||
targets = []
|
||||
for i, o in enumerate(output):
|
||||
box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
|
||||
j = torch.full((conf.shape[0], 1), i)
|
||||
targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
|
||||
targets = torch.cat(targets, 0).numpy()
|
||||
return targets[:, 0], targets[:, 1], targets[:, 2:]
|
||||
|
||||
|
||||
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
|
||||
"""
|
||||
Visualize feature maps of a given model module during inference.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Features to be visualized.
|
||||
module_type (str): Module type.
|
||||
stage (int): Module stage within the model.
|
||||
n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
|
||||
save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
|
||||
"""
|
||||
for m in ['Detect', 'Pose', 'Segment']:
|
||||
if m in module_type:
|
||||
return
|
||||
batch, channels, height, width = x.shape # batch, channels, height, width
|
||||
if height > 1 and width > 1:
|
||||
f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
|
||||
|
||||
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
|
||||
n = min(n, channels) # number of plots
|
||||
fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
|
||||
ax = ax.ravel()
|
||||
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
||||
for i in range(n):
|
||||
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
|
||||
ax[i].axis('off')
|
||||
|
||||
LOGGER.info(f'Saving {f}... ({n}/{channels})')
|
||||
plt.savefig(f, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
|
@ -1,276 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .checks import check_version
|
||||
from .metrics import bbox_iou
|
||||
|
||||
TORCH_1_10 = check_version(torch.__version__, '1.10.0')
|
||||
|
||||
|
||||
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
|
||||
"""select the positive anchor center in gt
|
||||
|
||||
Args:
|
||||
xy_centers (Tensor): shape(h*w, 4)
|
||||
gt_bboxes (Tensor): shape(b, n_boxes, 4)
|
||||
Return:
|
||||
(Tensor): shape(b, n_boxes, h*w)
|
||||
"""
|
||||
n_anchors = xy_centers.shape[0]
|
||||
bs, n_boxes, _ = gt_bboxes.shape
|
||||
lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
|
||||
bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
|
||||
# return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
|
||||
return bbox_deltas.amin(3).gt_(eps)
|
||||
|
||||
|
||||
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
|
||||
"""if an anchor box is assigned to multiple gts,
|
||||
the one with the highest iou will be selected.
|
||||
|
||||
Args:
|
||||
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
|
||||
overlaps (Tensor): shape(b, n_max_boxes, h*w)
|
||||
Return:
|
||||
target_gt_idx (Tensor): shape(b, h*w)
|
||||
fg_mask (Tensor): shape(b, h*w)
|
||||
mask_pos (Tensor): shape(b, n_max_boxes, h*w)
|
||||
"""
|
||||
# (b, n_max_boxes, h*w) -> (b, h*w)
|
||||
fg_mask = mask_pos.sum(-2)
|
||||
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
|
||||
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
|
||||
max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
|
||||
|
||||
is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
|
||||
is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
|
||||
|
||||
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w)
|
||||
fg_mask = mask_pos.sum(-2)
|
||||
# Find each grid serve which gt(index)
|
||||
target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
|
||||
return target_gt_idx, fg_mask, mask_pos
|
||||
|
||||
|
||||
class TaskAlignedAssigner(nn.Module):
|
||||
"""
|
||||
A task-aligned assigner for object detection.
|
||||
|
||||
This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric,
|
||||
which combines both classification and localization information.
|
||||
|
||||
Attributes:
|
||||
topk (int): The number of top candidates to consider.
|
||||
num_classes (int): The number of object classes.
|
||||
alpha (float): The alpha parameter for the classification component of the task-aligned metric.
|
||||
beta (float): The beta parameter for the localization component of the task-aligned metric.
|
||||
eps (float): A small value to prevent division by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
|
||||
"""Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
|
||||
super().__init__()
|
||||
self.topk = topk
|
||||
self.num_classes = num_classes
|
||||
self.bg_idx = num_classes
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.eps = eps
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
||||
"""
|
||||
Compute the task-aligned assignment.
|
||||
Reference https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
|
||||
|
||||
Args:
|
||||
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
||||
pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
||||
anc_points (Tensor): shape(num_total_anchors, 2)
|
||||
gt_labels (Tensor): shape(bs, n_max_boxes, 1)
|
||||
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
|
||||
mask_gt (Tensor): shape(bs, n_max_boxes, 1)
|
||||
|
||||
Returns:
|
||||
target_labels (Tensor): shape(bs, num_total_anchors)
|
||||
target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
|
||||
target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
|
||||
fg_mask (Tensor): shape(bs, num_total_anchors)
|
||||
target_gt_idx (Tensor): shape(bs, num_total_anchors)
|
||||
"""
|
||||
self.bs = pd_scores.size(0)
|
||||
self.n_max_boxes = gt_bboxes.size(1)
|
||||
|
||||
if self.n_max_boxes == 0:
|
||||
device = gt_bboxes.device
|
||||
return (torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), torch.zeros_like(pd_bboxes).to(device),
|
||||
torch.zeros_like(pd_scores).to(device), torch.zeros_like(pd_scores[..., 0]).to(device),
|
||||
torch.zeros_like(pd_scores[..., 0]).to(device))
|
||||
|
||||
mask_pos, align_metric, overlaps = self.get_pos_mask(pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points,
|
||||
mask_gt)
|
||||
|
||||
target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
|
||||
|
||||
# Assigned target
|
||||
target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
|
||||
|
||||
# Normalize
|
||||
align_metric *= mask_pos
|
||||
pos_align_metrics = align_metric.amax(axis=-1, keepdim=True) # b, max_num_obj
|
||||
pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True) # b, max_num_obj
|
||||
norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
|
||||
target_scores = target_scores * norm_align_metric
|
||||
|
||||
return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
|
||||
|
||||
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
|
||||
"""Get in_gts mask, (b, max_num_obj, h*w)."""
|
||||
mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
|
||||
# Get anchor_align metric, (b, max_num_obj, h*w)
|
||||
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
|
||||
# Get topk_metric mask, (b, max_num_obj, h*w)
|
||||
mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
|
||||
# Merge all mask to a final mask, (b, max_num_obj, h*w)
|
||||
mask_pos = mask_topk * mask_in_gts * mask_gt
|
||||
|
||||
return mask_pos, align_metric, overlaps
|
||||
|
||||
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
|
||||
"""Compute alignment metric given predicted and ground truth bounding boxes."""
|
||||
na = pd_bboxes.shape[-2]
|
||||
mask_gt = mask_gt.bool() # b, max_num_obj, h*w
|
||||
overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
|
||||
bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
|
||||
|
||||
ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj
|
||||
ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes) # b, max_num_obj
|
||||
ind[1] = gt_labels.squeeze(-1) # b, max_num_obj
|
||||
# Get the scores of each grid for each gt cls
|
||||
bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w
|
||||
|
||||
# (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
|
||||
pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
|
||||
gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
|
||||
overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
|
||||
|
||||
align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
|
||||
return align_metric, overlaps
|
||||
|
||||
def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
|
||||
"""
|
||||
Select the top-k candidates based on the given metrics.
|
||||
|
||||
Args:
|
||||
metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
|
||||
max_num_obj is the maximum number of objects, and h*w represents the
|
||||
total number of anchor points.
|
||||
largest (bool): If True, select the largest values; otherwise, select the smallest values.
|
||||
topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
|
||||
topk is the number of top candidates to consider. If not provided,
|
||||
the top-k values are automatically computed based on the given metrics.
|
||||
|
||||
Returns:
|
||||
(Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
|
||||
"""
|
||||
|
||||
# (b, max_num_obj, topk)
|
||||
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
|
||||
if topk_mask is None:
|
||||
topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
|
||||
# (b, max_num_obj, topk)
|
||||
topk_idxs.masked_fill_(~topk_mask, 0)
|
||||
|
||||
# (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
|
||||
count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
|
||||
ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
|
||||
for k in range(self.topk):
|
||||
# Expand topk_idxs for each value of k and add 1 at the specified positions
|
||||
count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
|
||||
# count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
|
||||
# filter invalid bboxes
|
||||
count_tensor.masked_fill_(count_tensor > 1, 0)
|
||||
|
||||
return count_tensor.to(metrics.dtype)
|
||||
|
||||
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
|
||||
"""
|
||||
Compute target labels, target bounding boxes, and target scores for the positive anchor points.
|
||||
|
||||
Args:
|
||||
gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
|
||||
batch size and max_num_obj is the maximum number of objects.
|
||||
gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
|
||||
target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
|
||||
anchor points, with shape (b, h*w), where h*w is the total
|
||||
number of anchor points.
|
||||
fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
|
||||
(foreground) anchor points.
|
||||
|
||||
Returns:
|
||||
(Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
|
||||
- target_labels (Tensor): Shape (b, h*w), containing the target labels for
|
||||
positive anchor points.
|
||||
- target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes
|
||||
for positive anchor points.
|
||||
- target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
|
||||
for positive anchor points, where num_classes is the number
|
||||
of object classes.
|
||||
"""
|
||||
|
||||
# Assigned target labels, (b, 1)
|
||||
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
|
||||
target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)
|
||||
target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w)
|
||||
|
||||
# Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
|
||||
target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
|
||||
|
||||
# Assigned target scores
|
||||
target_labels.clamp_(0)
|
||||
|
||||
# 10x faster than F.one_hot()
|
||||
target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
|
||||
dtype=torch.int64,
|
||||
device=target_labels.device) # (b, h*w, 80)
|
||||
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
|
||||
|
||||
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)
|
||||
target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
|
||||
|
||||
return target_labels, target_bboxes, target_scores
|
||||
|
||||
|
||||
def make_anchors(feats, strides, grid_cell_offset=0.5):
|
||||
"""Generate anchors from features."""
|
||||
anchor_points, stride_tensor = [], []
|
||||
assert feats is not None
|
||||
dtype, device = feats[0].dtype, feats[0].device
|
||||
for i, stride in enumerate(strides):
|
||||
_, _, h, w = feats[i].shape
|
||||
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
|
||||
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
|
||||
sy, sx = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
|
||||
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
|
||||
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
|
||||
return torch.cat(anchor_points), torch.cat(stride_tensor)
|
||||
|
||||
|
||||
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
||||
"""Transform distance(ltrb) to box(xywh or xyxy)."""
|
||||
lt, rb = distance.chunk(2, dim)
|
||||
x1y1 = anchor_points - lt
|
||||
x2y2 = anchor_points + rb
|
||||
if xywh:
|
||||
c_xy = (x1y1 + x2y2) / 2
|
||||
wh = x2y2 - x1y1
|
||||
return torch.cat((c_xy, wh), dim) # xywh bbox
|
||||
return torch.cat((x1y1, x2y2), dim) # xyxy bbox
|
||||
|
||||
|
||||
def bbox2dist(anchor_points, bbox, reg_max):
|
||||
"""Transform bbox(xyxy) to dist(ltrb)."""
|
||||
x1y1, x2y2 = bbox.chunk(2, -1)
|
||||
return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb)
|
@ -1,512 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, __version__
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_version
|
||||
|
||||
try:
|
||||
import thop
|
||||
except ImportError:
|
||||
thop = None
|
||||
|
||||
TORCHVISION_0_10 = check_version(torchvision.__version__, '0.10.0')
|
||||
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
|
||||
TORCH_1_11 = check_version(torch.__version__, '1.11.0')
|
||||
TORCH_1_12 = check_version(torch.__version__, '1.12.0')
|
||||
TORCH_2_0 = check_version(torch.__version__, minimum='2.0')
|
||||
|
||||
|
||||
@contextmanager
|
||||
def torch_distributed_zero_first(local_rank: int):
|
||||
"""Decorator to make all processes in distributed training wait for each local_master to do something."""
|
||||
initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
|
||||
if initialized and local_rank not in (-1, 0):
|
||||
dist.barrier(device_ids=[local_rank])
|
||||
yield
|
||||
if initialized and local_rank == 0:
|
||||
dist.barrier(device_ids=[0])
|
||||
|
||||
|
||||
def smart_inference_mode():
|
||||
"""Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
|
||||
|
||||
def decorate(fn):
|
||||
"""Applies appropriate torch decorator for inference mode based on torch version."""
|
||||
return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def get_cpu_info():
|
||||
"""Return a string with system CPU information, i.e. 'Apple M2'."""
|
||||
check_requirements('py-cpuinfo')
|
||||
import cpuinfo # noqa
|
||||
return cpuinfo.get_cpu_info()['brand_raw'].replace('(R)', '').replace('CPU ', '').replace('@ ', '')
|
||||
|
||||
|
||||
def select_device(device='', batch=0, newline=False, verbose=True):
|
||||
"""Selects PyTorch Device. Options are device = None or 'cpu' or 0 or '0' or '0,1,2,3'."""
|
||||
s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
|
||||
device = str(device).lower()
|
||||
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
|
||||
device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
||||
cpu = device == 'cpu'
|
||||
mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
|
||||
if cpu or mps:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
||||
elif device: # non-cpu device requested
|
||||
if device == 'cuda':
|
||||
device = '0'
|
||||
visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
|
||||
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
|
||||
LOGGER.info(s)
|
||||
install = 'See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ' \
|
||||
'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else ''
|
||||
raise ValueError(f"Invalid CUDA 'device={device}' requested."
|
||||
f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
|
||||
f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
|
||||
f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}'
|
||||
f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}'
|
||||
f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
|
||||
f'{install}')
|
||||
|
||||
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
||||
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
||||
n = len(devices) # device count
|
||||
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
|
||||
raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
|
||||
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.")
|
||||
space = ' ' * (len(s) + 1)
|
||||
for i, d in enumerate(devices):
|
||||
p = torch.cuda.get_device_properties(i)
|
||||
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
|
||||
arg = 'cuda:0'
|
||||
elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available() and TORCH_2_0:
|
||||
# Prefer MPS if available
|
||||
s += f'MPS ({get_cpu_info()})\n'
|
||||
arg = 'mps'
|
||||
else: # revert to CPU
|
||||
s += f'CPU ({get_cpu_info()})\n'
|
||||
arg = 'cpu'
|
||||
|
||||
if verbose and RANK == -1:
|
||||
LOGGER.info(s if newline else s.rstrip())
|
||||
return torch.device(arg)
|
||||
|
||||
|
||||
def time_sync():
|
||||
"""PyTorch-accurate time."""
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
return time.time()
|
||||
|
||||
|
||||
def fuse_conv_and_bn(conv, bn):
|
||||
"""Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
|
||||
fusedconv = nn.Conv2d(conv.in_channels,
|
||||
conv.out_channels,
|
||||
kernel_size=conv.kernel_size,
|
||||
stride=conv.stride,
|
||||
padding=conv.padding,
|
||||
dilation=conv.dilation,
|
||||
groups=conv.groups,
|
||||
bias=True).requires_grad_(False).to(conv.weight.device)
|
||||
|
||||
# Prepare filters
|
||||
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
||||
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
||||
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
|
||||
|
||||
# Prepare spatial bias
|
||||
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
|
||||
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
|
||||
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
|
||||
|
||||
return fusedconv
|
||||
|
||||
|
||||
def fuse_deconv_and_bn(deconv, bn):
|
||||
"""Fuse ConvTranspose2d() and BatchNorm2d() layers."""
|
||||
fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
|
||||
deconv.out_channels,
|
||||
kernel_size=deconv.kernel_size,
|
||||
stride=deconv.stride,
|
||||
padding=deconv.padding,
|
||||
output_padding=deconv.output_padding,
|
||||
dilation=deconv.dilation,
|
||||
groups=deconv.groups,
|
||||
bias=True).requires_grad_(False).to(deconv.weight.device)
|
||||
|
||||
# Prepare filters
|
||||
w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
|
||||
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
||||
fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
|
||||
|
||||
# Prepare spatial bias
|
||||
b_conv = torch.zeros(deconv.weight.size(1), device=deconv.weight.device) if deconv.bias is None else deconv.bias
|
||||
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
|
||||
fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
|
||||
|
||||
return fuseddconv
|
||||
|
||||
|
||||
def model_info(model, detailed=False, verbose=True, imgsz=640):
|
||||
"""Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]."""
|
||||
if not verbose:
|
||||
return
|
||||
n_p = get_num_params(model) # number of parameters
|
||||
n_g = get_num_gradients(model) # number of gradients
|
||||
n_l = len(list(model.modules())) # number of layers
|
||||
if detailed:
|
||||
LOGGER.info(
|
||||
f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
|
||||
for i, (name, p) in enumerate(model.named_parameters()):
|
||||
name = name.replace('module_list.', '')
|
||||
LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g %10s' %
|
||||
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype))
|
||||
|
||||
flops = get_flops(model, imgsz)
|
||||
fused = ' (fused)' if getattr(model, 'is_fused', lambda: False)() else ''
|
||||
fs = f', {flops:.1f} GFLOPs' if flops else ''
|
||||
yaml_file = getattr(model, 'yaml_file', '') or getattr(model, 'yaml', {}).get('yaml_file', '')
|
||||
model_name = Path(yaml_file).stem.replace('yolo', 'YOLO') or 'Model'
|
||||
LOGGER.info(f'{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}')
|
||||
return n_l, n_p, n_g, flops
|
||||
|
||||
|
||||
def get_num_params(model):
|
||||
"""Return the total number of parameters in a YOLO model."""
|
||||
return sum(x.numel() for x in model.parameters())
|
||||
|
||||
|
||||
def get_num_gradients(model):
|
||||
"""Return the total number of parameters with gradients in a YOLO model."""
|
||||
return sum(x.numel() for x in model.parameters() if x.requires_grad)
|
||||
|
||||
|
||||
def model_info_for_loggers(trainer):
|
||||
"""
|
||||
Return model info dict with useful model information.
|
||||
|
||||
Example for YOLOv8n:
|
||||
{'model/parameters': 3151904,
|
||||
'model/GFLOPs': 8.746,
|
||||
'model/speed_ONNX(ms)': 41.244,
|
||||
'model/speed_TensorRT(ms)': 3.211,
|
||||
'model/speed_PyTorch(ms)': 18.755}
|
||||
"""
|
||||
if trainer.args.profile: # profile ONNX and TensorRT times
|
||||
from ultralytics.yolo.utils.benchmarks import ProfileModels
|
||||
results = ProfileModels([trainer.last], device=trainer.device).profile()[0]
|
||||
results.pop('model/name')
|
||||
else: # only return PyTorch times from most recent validation
|
||||
results = {
|
||||
'model/parameters': get_num_params(trainer.model),
|
||||
'model/GFLOPs': round(get_flops(trainer.model), 3)}
|
||||
results['model/speed_PyTorch(ms)'] = round(trainer.validator.speed['inference'], 3)
|
||||
return results
|
||||
|
||||
|
||||
def get_flops(model, imgsz=640):
|
||||
"""Return a YOLO model's FLOPs."""
|
||||
try:
|
||||
model = de_parallel(model)
|
||||
p = next(model.parameters())
|
||||
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
|
||||
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
|
||||
flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 if thop else 0 # stride GFLOPs
|
||||
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
|
||||
return flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def get_flops_with_torch_profiler(model, imgsz=640):
|
||||
"""Compute model FLOPs (thop alternative)."""
|
||||
model = de_parallel(model)
|
||||
p = next(model.parameters())
|
||||
stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride
|
||||
im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
|
||||
with torch.profiler.profile(with_flops=True) as prof:
|
||||
model(im)
|
||||
flops = sum(x.flops for x in prof.key_averages()) / 1E9
|
||||
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
|
||||
flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
|
||||
return flops
|
||||
|
||||
|
||||
def initialize_weights(model):
|
||||
"""Initialize model weights to random values."""
|
||||
for m in model.modules():
|
||||
t = type(m)
|
||||
if t is nn.Conv2d:
|
||||
pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif t is nn.BatchNorm2d:
|
||||
m.eps = 1e-3
|
||||
m.momentum = 0.03
|
||||
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
|
||||
m.inplace = True
|
||||
|
||||
|
||||
def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
|
||||
# Scales img(bs,3,y,x) by ratio constrained to gs-multiple
|
||||
if ratio == 1.0:
|
||||
return img
|
||||
h, w = img.shape[2:]
|
||||
s = (int(h * ratio), int(w * ratio)) # new size
|
||||
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
|
||||
if not same_shape: # pad/crop img
|
||||
h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
|
||||
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
|
||||
|
||||
|
||||
def make_divisible(x, divisor):
|
||||
"""Returns nearest x divisible by divisor."""
|
||||
if isinstance(divisor, torch.Tensor):
|
||||
divisor = int(divisor.max()) # to int
|
||||
return math.ceil(x / divisor) * divisor
|
||||
|
||||
|
||||
def copy_attr(a, b, include=(), exclude=()):
|
||||
"""Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
|
||||
for k, v in b.__dict__.items():
|
||||
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
|
||||
continue
|
||||
else:
|
||||
setattr(a, k, v)
|
||||
|
||||
|
||||
def get_latest_opset():
|
||||
"""Return second-most (for maturity) recently supported ONNX opset by this version of torch."""
|
||||
return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) - 1 # opset
|
||||
|
||||
|
||||
def intersect_dicts(da, db, exclude=()):
|
||||
"""Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values."""
|
||||
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
|
||||
|
||||
|
||||
def is_parallel(model):
|
||||
"""Returns True if model is of type DP or DDP."""
|
||||
return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
|
||||
|
||||
|
||||
def de_parallel(model):
|
||||
"""De-parallelize a model: returns single-GPU model if model is of type DP or DDP."""
|
||||
return model.module if is_parallel(model) else model
|
||||
|
||||
|
||||
def one_cycle(y1=0.0, y2=1.0, steps=100):
|
||||
"""Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
|
||||
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
|
||||
|
||||
|
||||
def init_seeds(seed=0, deterministic=False):
|
||||
"""Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
|
||||
# torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
|
||||
if deterministic:
|
||||
if TORCH_2_0:
|
||||
torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible
|
||||
torch.backends.cudnn.deterministic = True
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
else:
|
||||
LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.')
|
||||
else:
|
||||
torch.use_deterministic_algorithms(False)
|
||||
torch.backends.cudnn.deterministic = False
|
||||
|
||||
|
||||
class ModelEMA:
|
||||
"""Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
|
||||
Keeps a moving average of everything in the model state_dict (parameters and buffers)
|
||||
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
||||
To disable EMA set the `enabled` attribute to `False`.
|
||||
"""
|
||||
|
||||
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
||||
"""Create EMA."""
|
||||
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
||||
self.updates = updates # number of EMA updates
|
||||
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
||||
for p in self.ema.parameters():
|
||||
p.requires_grad_(False)
|
||||
self.enabled = True
|
||||
|
||||
def update(self, model):
|
||||
"""Update EMA parameters."""
|
||||
if self.enabled:
|
||||
self.updates += 1
|
||||
d = self.decay(self.updates)
|
||||
|
||||
msd = de_parallel(model).state_dict() # model state_dict
|
||||
for k, v in self.ema.state_dict().items():
|
||||
if v.dtype.is_floating_point: # true for FP16 and FP32
|
||||
v *= d
|
||||
v += (1 - d) * msd[k].detach()
|
||||
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
|
||||
|
||||
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
||||
"""Updates attributes and saves stripped model with optimizer removed."""
|
||||
if self.enabled:
|
||||
copy_attr(self.ema, model, include, exclude)
|
||||
|
||||
|
||||
def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
|
||||
"""
|
||||
Strip optimizer from 'f' to finalize training, optionally save as 's'.
|
||||
|
||||
Args:
|
||||
f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
|
||||
s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Usage:
|
||||
from pathlib import Path
|
||||
from ultralytics.yolo.utils.torch_utils import strip_optimizer
|
||||
for f in Path('/Users/glennjocher/Downloads/weights').rglob('*.pt'):
|
||||
strip_optimizer(f)
|
||||
"""
|
||||
# Use dill (if exists) to serialize the lambda functions where pickle does not do this
|
||||
try:
|
||||
import dill as pickle
|
||||
except ImportError:
|
||||
import pickle
|
||||
|
||||
x = torch.load(f, map_location=torch.device('cpu'))
|
||||
args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None # combine args
|
||||
if x.get('ema'):
|
||||
x['model'] = x['ema'] # replace model with ema
|
||||
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
|
||||
x[k] = None
|
||||
x['epoch'] = -1
|
||||
x['model'].half() # to FP16
|
||||
for p in x['model'].parameters():
|
||||
p.requires_grad = False
|
||||
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
|
||||
# x['model'].args = x['train_args']
|
||||
torch.save(x, s or f, pickle_module=pickle)
|
||||
mb = os.path.getsize(s or f) / 1E6 # filesize
|
||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||
|
||||
|
||||
def profile(input, ops, n=10, device=None):
|
||||
"""
|
||||
YOLOv8 speed/memory/FLOPs profiler
|
||||
|
||||
Usage:
|
||||
input = torch.randn(16, 3, 640, 640)
|
||||
m1 = lambda x: x * torch.sigmoid(x)
|
||||
m2 = nn.SiLU()
|
||||
profile(input, [m1, m2], n=100) # profile over 100 iterations
|
||||
"""
|
||||
results = []
|
||||
if not isinstance(device, torch.device):
|
||||
device = select_device(device)
|
||||
LOGGER.info(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
||||
f"{'input':>24s}{'output':>24s}")
|
||||
|
||||
for x in input if isinstance(input, list) else [input]:
|
||||
x = x.to(device)
|
||||
x.requires_grad = True
|
||||
for m in ops if isinstance(ops, list) else [ops]:
|
||||
m = m.to(device) if hasattr(m, 'to') else m # device
|
||||
m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
|
||||
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
|
||||
try:
|
||||
flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1E9 * 2 if thop else 0 # GFLOPs
|
||||
except Exception:
|
||||
flops = 0
|
||||
|
||||
try:
|
||||
for _ in range(n):
|
||||
t[0] = time_sync()
|
||||
y = m(x)
|
||||
t[1] = time_sync()
|
||||
try:
|
||||
_ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
|
||||
t[2] = time_sync()
|
||||
except Exception: # no backward method
|
||||
# print(e) # for debug
|
||||
t[2] = float('nan')
|
||||
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
||||
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
||||
mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
|
||||
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
|
||||
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
|
||||
LOGGER.info(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
|
||||
results.append([p, flops, mem, tf, tb, s_in, s_out])
|
||||
except Exception as e:
|
||||
LOGGER.info(e)
|
||||
results.append(None)
|
||||
torch.cuda.empty_cache()
|
||||
return results
|
||||
|
||||
|
||||
class EarlyStopping:
|
||||
"""
|
||||
Early stopping class that stops training when a specified number of epochs have passed without improvement.
|
||||
"""
|
||||
|
||||
def __init__(self, patience=50):
|
||||
"""
|
||||
Initialize early stopping object
|
||||
|
||||
Args:
|
||||
patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
|
||||
"""
|
||||
self.best_fitness = 0.0 # i.e. mAP
|
||||
self.best_epoch = 0
|
||||
self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
|
||||
self.possible_stop = False # possible stop may occur next epoch
|
||||
|
||||
def __call__(self, epoch, fitness):
|
||||
"""
|
||||
Check whether to stop training
|
||||
|
||||
Args:
|
||||
epoch (int): Current epoch of training
|
||||
fitness (float): Fitness value of current epoch
|
||||
|
||||
Returns:
|
||||
(bool): True if training should stop, False otherwise
|
||||
"""
|
||||
if fitness is None: # check if fitness=None (happens when val=False)
|
||||
return False
|
||||
|
||||
if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
|
||||
self.best_epoch = epoch
|
||||
self.best_fitness = fitness
|
||||
delta = epoch - self.best_epoch # epochs without improvement
|
||||
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
|
||||
stop = delta >= self.patience # stop training if patience exceeded
|
||||
if stop:
|
||||
LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
|
||||
f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
|
||||
f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
|
||||
f'i.e. `patience=300` or use `patience=0` to disable EarlyStopping.')
|
||||
return stop
|
@ -1,120 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
from ultralytics.yolo.cfg import TASK2DATA, TASK2METRIC
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, LOGGER, NUM_THREADS
|
||||
|
||||
|
||||
def run_ray_tune(model,
|
||||
space: dict = None,
|
||||
grace_period: int = 10,
|
||||
gpu_per_trial: int = None,
|
||||
max_samples: int = 10,
|
||||
**train_args):
|
||||
"""
|
||||
Runs hyperparameter tuning using Ray Tune.
|
||||
|
||||
Args:
|
||||
model (YOLO): Model to run the tuner on.
|
||||
space (dict, optional): The hyperparameter search space. Defaults to None.
|
||||
grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10.
|
||||
gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None.
|
||||
max_samples (int, optional): The maximum number of trials to run. Defaults to 10.
|
||||
train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the results of the hyperparameter search.
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError: If Ray Tune is not installed.
|
||||
"""
|
||||
if train_args is None:
|
||||
train_args = {}
|
||||
|
||||
try:
|
||||
from ray import tune
|
||||
from ray.air import RunConfig
|
||||
from ray.air.integrations.wandb import WandbLoggerCallback
|
||||
from ray.tune.schedulers import ASHAScheduler
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError("Tuning hyperparameters requires Ray Tune. Install with: pip install 'ray[tune]'")
|
||||
|
||||
try:
|
||||
import wandb
|
||||
|
||||
assert hasattr(wandb, '__version__')
|
||||
except (ImportError, AssertionError):
|
||||
wandb = False
|
||||
|
||||
default_space = {
|
||||
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
|
||||
'lr0': tune.uniform(1e-5, 1e-1),
|
||||
'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
|
||||
'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
|
||||
'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4
|
||||
'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
|
||||
'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum
|
||||
'box': tune.uniform(0.02, 0.2), # box loss gain
|
||||
'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
|
||||
'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
|
||||
'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
|
||||
'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
|
||||
'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg)
|
||||
'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction)
|
||||
'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain)
|
||||
'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg)
|
||||
'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
|
||||
'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability)
|
||||
'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability)
|
||||
'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability)
|
||||
'mixup': tune.uniform(0.0, 1.0), # image mixup (probability)
|
||||
'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability)
|
||||
|
||||
def _tune(config):
|
||||
"""
|
||||
Trains the YOLO model with the specified hyperparameters and additional arguments.
|
||||
|
||||
Args:
|
||||
config (dict): A dictionary of hyperparameters to use for training.
|
||||
|
||||
Returns:
|
||||
None.
|
||||
"""
|
||||
model._reset_callbacks()
|
||||
config.update(train_args)
|
||||
model.train(**config)
|
||||
|
||||
# Get search space
|
||||
if not space:
|
||||
space = default_space
|
||||
LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.')
|
||||
|
||||
# Get dataset
|
||||
data = train_args.get('data', TASK2DATA[model.task])
|
||||
space['data'] = data
|
||||
if 'data' not in train_args:
|
||||
LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')
|
||||
|
||||
# Define the trainable function with allocated resources
|
||||
trainable_with_resources = tune.with_resources(_tune, {'cpu': NUM_THREADS, 'gpu': gpu_per_trial or 0})
|
||||
|
||||
# Define the ASHA scheduler for hyperparameter search
|
||||
asha_scheduler = ASHAScheduler(time_attr='epoch',
|
||||
metric=TASK2METRIC[model.task],
|
||||
mode='max',
|
||||
max_t=train_args.get('epochs') or DEFAULT_CFG_DICT['epochs'] or 100,
|
||||
grace_period=grace_period,
|
||||
reduction_factor=3)
|
||||
|
||||
# Define the callbacks for the hyperparameter search
|
||||
tuner_callbacks = [WandbLoggerCallback(project='YOLOv8-tune')] if wandb else []
|
||||
|
||||
# Create the Ray Tune hyperparameter search tuner
|
||||
tuner = tune.Tuner(trainable_with_resources,
|
||||
param_space=space,
|
||||
tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),
|
||||
run_config=RunConfig(callbacks=tuner_callbacks, storage_path='./runs/tune'))
|
||||
|
||||
# Run the hyperparameter search
|
||||
tuner.fit()
|
||||
|
||||
# Return the results of the hyperparameter search
|
||||
return tuner.get_results()
|
@ -1,5 +1,10 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
from ultralytics.yolo.v8 import classify, detect, pose, segment
|
||||
from ultralytics.utils import LOGGER
|
||||
|
||||
__all__ = 'classify', 'segment', 'detect', 'pose'
|
||||
# Set modules in sys.modules under their old name
|
||||
sys.modules['ultralytics.yolo.v8'] = importlib.import_module('ultralytics.models.yolo')
|
||||
|
||||
LOGGER.warning("WARNING ⚠️ 'ultralytics.yolo.v8' is deprecated since '8.0.136' and will be removed in '8.1.0'. "
|
||||
"Please use 'ultralytics.models.yolo' instead.")
|
||||
|
@ -1,7 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from ultralytics.yolo.v8.classify.predict import ClassificationPredictor, predict
|
||||
from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
|
||||
from ultralytics.yolo.v8.classify.val import ClassificationValidator, val
|
||||
|
||||
__all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val'
|
@ -1,51 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT
|
||||
|
||||
|
||||
class ClassificationPredictor(BasePredictor):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = 'classify'
|
||||
|
||||
def preprocess(self, img):
|
||||
"""Converts input image to model-compatible data type."""
|
||||
if not isinstance(img, torch.Tensor):
|
||||
img = torch.stack([self.transforms(im) for im in img], dim=0)
|
||||
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
|
||||
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Postprocesses predictions to return Results objects."""
|
||||
results = []
|
||||
for i, pred in enumerate(preds):
|
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||
path = self.batch[0]
|
||||
img_path = path[i] if isinstance(path, list) else path
|
||||
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Run YOLO model predictions on input images/videos."""
|
||||
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
|
||||
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||
else 'https://ultralytics.com/images/bus.jpg'
|
||||
|
||||
args = dict(model=model, source=source)
|
||||
if use_python:
|
||||
from ultralytics import YOLO
|
||||
YOLO(model)(**args)
|
||||
else:
|
||||
predictor = ClassificationPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
predict()
|
@ -1,161 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
|
||||
from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.data import ClassificationDataset, build_dataloader
|
||||
from ultralytics.yolo.engine.trainer import BaseTrainer
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
|
||||
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
||||
from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
|
||||
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides['task'] = 'classify'
|
||||
if overrides.get('imgsz') is None:
|
||||
overrides['imgsz'] = 224
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""Set the YOLO model's class names from the loaded dataset."""
|
||||
self.model.names = self.data['names']
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Returns a modified PyTorch model configured for training YOLO."""
|
||||
model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
for m in model.modules():
|
||||
if not self.args.pretrained and hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
|
||||
m.p = self.args.dropout # set dropout
|
||||
for p in model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
return model
|
||||
|
||||
def setup_model(self):
|
||||
"""
|
||||
load/create/download model for any task
|
||||
"""
|
||||
# Classification models require special handling
|
||||
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
||||
model = str(self.model)
|
||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||
if model.endswith('.pt'):
|
||||
self.model, _ = attempt_load_one_weight(model, device='cpu')
|
||||
for p in self.model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
elif model.endswith('.yaml'):
|
||||
self.model = self.get_model(cfg=model)
|
||||
elif model in torchvision.models.__dict__:
|
||||
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
|
||||
else:
|
||||
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
|
||||
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
|
||||
|
||||
return # dont return ckpt. Classification doesn't support resume
|
||||
|
||||
def build_dataset(self, img_path, mode='train', batch=None):
|
||||
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
|
||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = self.build_dataset(dataset_path, mode)
|
||||
|
||||
loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
|
||||
# Attach inference transforms
|
||||
if mode != 'train':
|
||||
if is_parallel(self.model):
|
||||
self.model.module.transforms = loader.dataset.torch_transforms
|
||||
else:
|
||||
self.model.transforms = loader.dataset.torch_transforms
|
||||
return loader
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""Preprocesses a batch of images and classes."""
|
||||
batch['img'] = batch['img'].to(self.device)
|
||||
batch['cls'] = batch['cls'].to(self.device)
|
||||
return batch
|
||||
|
||||
def progress_string(self):
|
||||
"""Returns a formatted string showing training progress."""
|
||||
return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
|
||||
('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
|
||||
|
||||
def get_validator(self):
|
||||
"""Returns an instance of ClassificationValidator for validation."""
|
||||
self.loss_names = ['loss']
|
||||
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||
"""
|
||||
Returns a loss dict with labelled training loss items tensor
|
||||
"""
|
||||
# Not needed for classification but necessary for segmentation & detection
|
||||
keys = [f'{prefix}/{x}' for x in self.loss_names]
|
||||
if loss_items is None:
|
||||
return keys
|
||||
loss_items = [round(float(loss_items), 5)]
|
||||
return dict(zip(keys, loss_items))
|
||||
|
||||
def resume_training(self, ckpt):
|
||||
"""Resumes training from a given checkpoint."""
|
||||
pass
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plots metrics from a CSV file."""
|
||||
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
|
||||
|
||||
def final_eval(self):
|
||||
"""Evaluate trained model and save validation results."""
|
||||
for f in self.last, self.best:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
# TODO: validate best.pt after training completes
|
||||
# if f is self.best:
|
||||
# LOGGER.info(f'\nValidating {f}...')
|
||||
# self.validator.args.save_json = True
|
||||
# self.metrics = self.validator(model=f)
|
||||
# self.metrics.pop('fitness', None)
|
||||
# self.run_callbacks('on_fit_epoch_end')
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plots training samples with their annotations."""
|
||||
plot_images(images=batch['img'],
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=batch['cls'].squeeze(-1),
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg',
|
||||
on_plot=self.on_plot)
|
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Train the YOLO classification model."""
|
||||
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
|
||||
data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
|
||||
device = cfg.device if cfg.device is not None else ''
|
||||
|
||||
args = dict(model=model, data=data, device=device)
|
||||
if use_python:
|
||||
from ultralytics import YOLO
|
||||
YOLO(model).train(**args)
|
||||
else:
|
||||
trainer = ClassificationTrainer(overrides=args)
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
@ -1,109 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.data import ClassificationDataset, build_dataloader
|
||||
from ultralytics.yolo.engine.validator import BaseValidator
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER
|
||||
from ultralytics.yolo.utils.metrics import ClassifyMetrics, ConfusionMatrix
|
||||
from ultralytics.yolo.utils.plotting import plot_images
|
||||
|
||||
|
||||
class ClassificationValidator(BaseValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
"""Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
|
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.args.task = 'classify'
|
||||
self.metrics = ClassifyMetrics()
|
||||
|
||||
def get_desc(self):
|
||||
"""Returns a formatted string summarizing classification metrics."""
|
||||
return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
|
||||
self.names = model.names
|
||||
self.nc = len(model.names)
|
||||
self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify')
|
||||
self.pred = []
|
||||
self.targets = []
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocesses input batch and returns it."""
|
||||
batch['img'] = batch['img'].to(self.device, non_blocking=True)
|
||||
batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
|
||||
batch['cls'] = batch['cls'].to(self.device)
|
||||
return batch
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
"""Updates running metrics with model predictions and batch targets."""
|
||||
n5 = min(len(self.model.names), 5)
|
||||
self.pred.append(preds.argsort(1, descending=True)[:, :n5])
|
||||
self.targets.append(batch['cls'])
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
"""Finalizes metrics of the model such as confusion_matrix and speed."""
|
||||
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
|
||||
if self.args.plots:
|
||||
for normalize in True, False:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir,
|
||||
names=self.names.values(),
|
||||
normalize=normalize,
|
||||
on_plot=self.on_plot)
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
def get_stats(self):
|
||||
"""Returns a dictionary of metrics obtained by processing targets and predictions."""
|
||||
self.metrics.process(self.targets, self.pred)
|
||||
return self.metrics.results_dict
|
||||
|
||||
def build_dataset(self, img_path):
|
||||
return ClassificationDataset(root=img_path, args=self.args, augment=False)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
"""Builds and returns a data loader for classification tasks with given parameters."""
|
||||
dataset = self.build_dataset(dataset_path)
|
||||
return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
|
||||
|
||||
def print_results(self):
|
||||
"""Prints evaluation metrics for YOLO object detection model."""
|
||||
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
|
||||
LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plot validation image samples."""
|
||||
plot_images(images=batch['img'],
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=batch['cls'].squeeze(-1),
|
||||
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots predicted bounding boxes on input images and saves the result."""
|
||||
plot_images(batch['img'],
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=torch.argmax(preds, dim=1),
|
||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||
names=self.names,
|
||||
on_plot=self.on_plot) # pred
|
||||
|
||||
|
||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Validate YOLO model using custom data."""
|
||||
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
|
||||
data = cfg.data or 'mnist160'
|
||||
|
||||
args = dict(model=model, data=data)
|
||||
if use_python:
|
||||
from ultralytics import YOLO
|
||||
YOLO(model).val(**args)
|
||||
else:
|
||||
validator = ClassificationValidator(args=args)
|
||||
validator(model=args['model'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
val()
|
@ -1,7 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .predict import DetectionPredictor, predict
|
||||
from .train import DetectionTrainer, train
|
||||
from .val import DetectionValidator, val
|
||||
|
||||
__all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val'
|
@ -1,48 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
|
||||
|
||||
|
||||
class DetectionPredictor(BasePredictor):
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Postprocesses predictions and returns a list of Results objects."""
|
||||
preds = ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
classes=self.args.classes)
|
||||
|
||||
results = []
|
||||
for i, pred in enumerate(preds):
|
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||
if not isinstance(orig_imgs, torch.Tensor):
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
path = self.batch[0]
|
||||
img_path = path[i] if isinstance(path, list) else path
|
||||
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
|
||||
return results
|
||||
|
||||
|
||||
def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Runs YOLO model inference on input image(s)."""
|
||||
model = cfg.model or 'yolov8n.pt'
|
||||
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||
else 'https://ultralytics.com/images/bus.jpg'
|
||||
|
||||
args = dict(model=model, source=source)
|
||||
if use_python:
|
||||
from ultralytics import YOLO
|
||||
YOLO(model)(**args)
|
||||
else:
|
||||
predictor = DetectionPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
predict()
|
@ -1,123 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
from copy import copy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ultralytics.nn.tasks import DetectionModel
|
||||
from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.data import build_dataloader, build_yolo_dataset
|
||||
from ultralytics.yolo.engine.trainer import BaseTrainer
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK
|
||||
from ultralytics.yolo.utils.plotting import plot_images, plot_labels, plot_results
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel, torch_distributed_zero_first
|
||||
|
||||
|
||||
# BaseTrainer python usage
|
||||
class DetectionTrainer(BaseTrainer):
|
||||
|
||||
def build_dataset(self, img_path, mode='train', batch=None):
|
||||
"""
|
||||
Build YOLO Dataset.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
||||
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
|
||||
"""
|
||||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
"""Construct and return dataloader."""
|
||||
assert mode in ['train', 'val']
|
||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = self.build_dataset(dataset_path, mode, batch_size)
|
||||
shuffle = mode == 'train'
|
||||
if getattr(dataset, 'rect', False) and shuffle:
|
||||
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||
shuffle = False
|
||||
workers = self.args.workers if mode == 'train' else self.args.workers * 2
|
||||
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""Preprocesses a batch of images by scaling and converting to float."""
|
||||
batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
|
||||
return batch
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)."""
|
||||
# self.args.box *= 3 / nl # scale to layers
|
||||
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
|
||||
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
|
||||
self.model.nc = self.data['nc'] # attach number of classes to model
|
||||
self.model.names = self.data['names'] # attach class names to model
|
||||
self.model.args = self.args # attach hyperparameters to model
|
||||
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Return a YOLO detection model."""
|
||||
model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
"""Returns a DetectionValidator for YOLO model validation."""
|
||||
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
|
||||
return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||
"""
|
||||
Returns a loss dict with labelled training loss items tensor
|
||||
"""
|
||||
# Not needed for classification but necessary for segmentation & detection
|
||||
keys = [f'{prefix}/{x}' for x in self.loss_names]
|
||||
if loss_items is not None:
|
||||
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
|
||||
return dict(zip(keys, loss_items))
|
||||
else:
|
||||
return keys
|
||||
|
||||
def progress_string(self):
|
||||
"""Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
|
||||
return ('\n' + '%11s' *
|
||||
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plots training samples with their annotations."""
|
||||
plot_images(images=batch['img'],
|
||||
batch_idx=batch['batch_idx'],
|
||||
cls=batch['cls'].squeeze(-1),
|
||||
bboxes=batch['bboxes'],
|
||||
paths=batch['im_file'],
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg',
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plots metrics from a CSV file."""
|
||||
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
|
||||
|
||||
def plot_training_labels(self):
|
||||
"""Create a labeled training plot of the YOLO model."""
|
||||
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
|
||||
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
|
||||
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)
|
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Train and optimize YOLO model given training data and device."""
|
||||
model = cfg.model or 'yolov8n.pt'
|
||||
data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist")
|
||||
device = cfg.device if cfg.device is not None else ''
|
||||
|
||||
args = dict(model=model, data=data, device=device)
|
||||
if use_python:
|
||||
from ultralytics import YOLO
|
||||
YOLO(model).train(**args)
|
||||
else:
|
||||
trainer = DetectionTrainer(overrides=args)
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
@ -1,276 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.data import build_dataloader, build_yolo_dataset
|
||||
from ultralytics.yolo.engine.validator import BaseValidator
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
|
||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
|
||||
|
||||
class DetectionValidator(BaseValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
"""Initialize detection model with necessary variables and settings."""
|
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.args.task = 'detect'
|
||||
self.is_coco = False
|
||||
self.class_map = None
|
||||
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
|
||||
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
|
||||
self.niou = self.iouv.numel()
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocesses batch of images for YOLO training."""
|
||||
batch['img'] = batch['img'].to(self.device, non_blocking=True)
|
||||
batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255
|
||||
for k in ['batch_idx', 'cls', 'bboxes']:
|
||||
batch[k] = batch[k].to(self.device)
|
||||
|
||||
nb = len(batch['img'])
|
||||
self.lb = [torch.cat([batch['cls'], batch['bboxes']], dim=-1)[batch['batch_idx'] == i]
|
||||
for i in range(nb)] if self.args.save_hybrid else [] # for autolabelling
|
||||
|
||||
return batch
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initialize evaluation metrics for YOLO."""
|
||||
val = self.data.get(self.args.split, '') # validation path
|
||||
self.is_coco = isinstance(val, str) and 'coco' in val and val.endswith(f'{os.sep}val2017.txt') # is COCO
|
||||
self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
|
||||
self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO
|
||||
self.names = model.names
|
||||
self.nc = len(model.names)
|
||||
self.metrics.names = self.names
|
||||
self.metrics.plot = self.args.plots
|
||||
self.confusion_matrix = ConfusionMatrix(nc=self.nc)
|
||||
self.seen = 0
|
||||
self.jdict = []
|
||||
self.stats = []
|
||||
|
||||
def get_desc(self):
|
||||
"""Return a formatted string summarizing class metrics of YOLO model."""
|
||||
return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)')
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Apply Non-maximum suppression to prediction outputs."""
|
||||
return ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
labels=self.lb,
|
||||
multi_label=True,
|
||||
agnostic=self.args.single_cls,
|
||||
max_det=self.args.max_det)
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
"""Metrics."""
|
||||
for si, pred in enumerate(preds):
|
||||
idx = batch['batch_idx'] == si
|
||||
cls = batch['cls'][idx]
|
||||
bbox = batch['bboxes'][idx]
|
||||
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
|
||||
shape = batch['ori_shape'][si]
|
||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
self.seen += 1
|
||||
|
||||
if npr == 0:
|
||||
if nl:
|
||||
self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), cls.squeeze(-1)))
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
|
||||
continue
|
||||
|
||||
# Predictions
|
||||
if self.args.single_cls:
|
||||
pred[:, 5] = 0
|
||||
predn = pred.clone()
|
||||
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
|
||||
ratio_pad=batch['ratio_pad'][si]) # native-space pred
|
||||
|
||||
# Evaluate
|
||||
if nl:
|
||||
height, width = batch['img'].shape[2:]
|
||||
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
|
||||
(width, height, width, height), device=self.device) # target boxes
|
||||
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
|
||||
ratio_pad=batch['ratio_pad'][si]) # native-space labels
|
||||
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
||||
correct_bboxes = self._process_batch(predn, labelsn)
|
||||
# TODO: maybe remove these `self.` arguments as they already are member variable
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(predn, labelsn)
|
||||
self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1))) # (conf, pcls, tcls)
|
||||
|
||||
# Save
|
||||
if self.args.save_json:
|
||||
self.pred_to_json(predn, batch['im_file'][si])
|
||||
if self.args.save_txt:
|
||||
file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt'
|
||||
self.save_one_txt(predn, self.args.save_conf, shape, file)
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
"""Set final values for metrics speed and confusion matrix."""
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
def get_stats(self):
|
||||
"""Returns metrics statistics and results dictionary."""
|
||||
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
|
||||
if len(stats) and stats[0].any():
|
||||
self.metrics.process(*stats)
|
||||
self.nt_per_class = np.bincount(stats[-1].astype(int), minlength=self.nc) # number of targets per class
|
||||
return self.metrics.results_dict
|
||||
|
||||
def print_results(self):
|
||||
"""Prints training/validation set metrics per class."""
|
||||
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
|
||||
LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
|
||||
if self.nt_per_class.sum() == 0:
|
||||
LOGGER.warning(
|
||||
f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
|
||||
|
||||
# Print results per class
|
||||
if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
|
||||
for i, c in enumerate(self.metrics.ap_class_index):
|
||||
LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
|
||||
|
||||
if self.args.plots:
|
||||
for normalize in True, False:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir,
|
||||
names=self.names.values(),
|
||||
normalize=normalize,
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def _process_batch(self, detections, labels):
|
||||
"""
|
||||
Return correct prediction matrix
|
||||
Arguments:
|
||||
detections (array[N, 6]), x1, y1, x2, y2, conf, class
|
||||
labels (array[M, 5]), class, x1, y1, x2, y2
|
||||
Returns:
|
||||
correct (array[N, 10]), for 10 IoU levels
|
||||
"""
|
||||
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
|
||||
correct_class = labels[:, 0:1] == detections[:, 5]
|
||||
for i in range(len(self.iouv)):
|
||||
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
|
||||
if x[0].shape[0]:
|
||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
|
||||
1).cpu().numpy() # [label, detect, iou]
|
||||
if x[0].shape[0] > 1:
|
||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
||||
# matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
||||
correct[matches[:, 1].astype(int), i] = True
|
||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
||||
|
||||
def build_dataset(self, img_path, mode='val', batch=None):
|
||||
"""Build YOLO Dataset
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
||||
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
|
||||
"""
|
||||
gs = max(int(de_parallel(self.model).stride if self.model else 0), 32)
|
||||
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=gs)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
"""Construct and return dataloader."""
|
||||
dataset = self.build_dataset(dataset_path, batch=batch_size, mode='val')
|
||||
return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plot validation image samples."""
|
||||
plot_images(batch['img'],
|
||||
batch['batch_idx'],
|
||||
batch['cls'].squeeze(-1),
|
||||
batch['bboxes'],
|
||||
paths=batch['im_file'],
|
||||
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots predicted bounding boxes on input images and saves the result."""
|
||||
plot_images(batch['img'],
|
||||
*output_to_target(preds, max_det=self.args.max_det),
|
||||
paths=batch['im_file'],
|
||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||
names=self.names,
|
||||
on_plot=self.on_plot) # pred
|
||||
|
||||
def save_one_txt(self, predn, save_conf, shape, file):
|
||||
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
|
||||
gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
|
||||
for *xyxy, conf, cls in predn.tolist():
|
||||
xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
|
||||
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
|
||||
with open(file, 'a') as f:
|
||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||
|
||||
def pred_to_json(self, predn, filename):
|
||||
"""Serialize YOLO predictions to COCO json format."""
|
||||
stem = Path(filename).stem
|
||||
image_id = int(stem) if stem.isnumeric() else stem
|
||||
box = ops.xyxy2xywh(predn[:, :4]) # xywh
|
||||
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
|
||||
for p, b in zip(predn.tolist(), box.tolist()):
|
||||
self.jdict.append({
|
||||
'image_id': image_id,
|
||||
'category_id': self.class_map[int(p[5])],
|
||||
'bbox': [round(x, 3) for x in b],
|
||||
'score': round(p[4], 5)})
|
||||
|
||||
def eval_json(self, stats):
|
||||
"""Evaluates YOLO output in JSON format and returns performance statistics."""
|
||||
if self.args.save_json and self.is_coco and len(self.jdict):
|
||||
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
|
||||
pred_json = self.save_dir / 'predictions.json' # predictions
|
||||
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
|
||||
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
||||
check_requirements('pycocotools>=2.0.6')
|
||||
from pycocotools.coco import COCO # noqa
|
||||
from pycocotools.cocoeval import COCOeval # noqa
|
||||
|
||||
for x in anno_json, pred_json:
|
||||
assert x.is_file(), f'{x} file not found'
|
||||
anno = COCO(str(anno_json)) # init annotations api
|
||||
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||
eval = COCOeval(anno, pred, 'bbox')
|
||||
if self.is_coco:
|
||||
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
|
||||
eval.evaluate()
|
||||
eval.accumulate()
|
||||
eval.summarize()
|
||||
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'pycocotools unable to run: {e}')
|
||||
return stats
|
||||
|
||||
|
||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Validate trained YOLO model on validation dataset."""
|
||||
model = cfg.model or 'yolov8n.pt'
|
||||
data = cfg.data or 'coco128.yaml'
|
||||
|
||||
args = dict(model=model, data=data)
|
||||
if use_python:
|
||||
from ultralytics import YOLO
|
||||
YOLO(model).val(**args)
|
||||
else:
|
||||
validator = DetectionValidator(args=args)
|
||||
validator(model=args['model'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
val()
|
@ -1,7 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .predict import PosePredictor, predict
|
||||
from .train import PoseTrainer, train
|
||||
from .val import PoseValidator, val
|
||||
|
||||
__all__ = 'PoseTrainer', 'train', 'PoseValidator', 'val', 'PosePredictor', 'predict'
|
@ -1,58 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
|
||||
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
||||
|
||||
|
||||
class PosePredictor(DetectionPredictor):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = 'pose'
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Return detection results for a given input image or list of images."""
|
||||
preds = ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
classes=self.args.classes,
|
||||
nc=len(self.model.names))
|
||||
|
||||
results = []
|
||||
for i, pred in enumerate(preds):
|
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||
shape = orig_img.shape
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
|
||||
pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
|
||||
pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, shape)
|
||||
path = self.batch[0]
|
||||
img_path = path[i] if isinstance(path, list) else path
|
||||
results.append(
|
||||
Results(orig_img=orig_img,
|
||||
path=img_path,
|
||||
names=self.model.names,
|
||||
boxes=pred[:, :6],
|
||||
keypoints=pred_kpts))
|
||||
return results
|
||||
|
||||
|
||||
def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Runs YOLO to predict objects in an image or video."""
|
||||
model = cfg.model or 'yolov8n-pose.pt'
|
||||
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||
else 'https://ultralytics.com/images/bus.jpg'
|
||||
|
||||
args = dict(model=model, source=source)
|
||||
if use_python:
|
||||
from ultralytics import YOLO
|
||||
YOLO(model)(**args)
|
||||
else:
|
||||
predictor = PosePredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
predict()
|
@ -1,77 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from copy import copy
|
||||
|
||||
from ultralytics.nn.tasks import PoseModel
|
||||
from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG
|
||||
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
||||
|
||||
|
||||
# BaseTrainer python usage
|
||||
class PoseTrainer(v8.detect.DetectionTrainer):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""Initialize a PoseTrainer object with specified configurations and overrides."""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides['task'] = 'pose'
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Get pose estimation model with specified configuration and weights."""
|
||||
model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
return model
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""Sets keypoints shape attribute of PoseModel."""
|
||||
super().set_model_attributes()
|
||||
self.model.kpt_shape = self.data['kpt_shape']
|
||||
|
||||
def get_validator(self):
|
||||
"""Returns an instance of the PoseValidator class for validation."""
|
||||
self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
|
||||
return v8.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
|
||||
images = batch['img']
|
||||
kpts = batch['keypoints']
|
||||
cls = batch['cls'].squeeze(-1)
|
||||
bboxes = batch['bboxes']
|
||||
paths = batch['im_file']
|
||||
batch_idx = batch['batch_idx']
|
||||
plot_images(images,
|
||||
batch_idx,
|
||||
cls,
|
||||
bboxes,
|
||||
kpts=kpts,
|
||||
paths=paths,
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg',
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plots training/val metrics."""
|
||||
plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png
|
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Train the YOLO model on the given data and device."""
|
||||
model = cfg.model or 'yolov8n-pose.yaml'
|
||||
data = cfg.data or 'coco8-pose.yaml'
|
||||
device = cfg.device if cfg.device is not None else ''
|
||||
|
||||
args = dict(model=model, data=data, device=device)
|
||||
if use_python:
|
||||
from ultralytics import YOLO
|
||||
YOLO(model).train(**args)
|
||||
else:
|
||||
trainer = PoseTrainer(overrides=args)
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
@ -1,224 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.yolo.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
|
||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
|
||||
from ultralytics.yolo.v8.detect import DetectionValidator
|
||||
|
||||
|
||||
class PoseValidator(DetectionValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
"""Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
|
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.args.task = 'pose'
|
||||
self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
|
||||
batch = super().preprocess(batch)
|
||||
batch['keypoints'] = batch['keypoints'].to(self.device).float()
|
||||
return batch
|
||||
|
||||
def get_desc(self):
|
||||
"""Returns description of evaluation metrics in string format."""
|
||||
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P',
|
||||
'R', 'mAP50', 'mAP50-95)')
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Apply non-maximum suppression and return detections with high confidence scores."""
|
||||
return ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
labels=self.lb,
|
||||
multi_label=True,
|
||||
agnostic=self.args.single_cls,
|
||||
max_det=self.args.max_det,
|
||||
nc=self.nc)
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initiate pose estimation metrics for YOLO model."""
|
||||
super().init_metrics(model)
|
||||
self.kpt_shape = self.data['kpt_shape']
|
||||
is_pose = self.kpt_shape == [17, 3]
|
||||
nkpt = self.kpt_shape[0]
|
||||
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
"""Metrics."""
|
||||
for si, pred in enumerate(preds):
|
||||
idx = batch['batch_idx'] == si
|
||||
cls = batch['cls'][idx]
|
||||
bbox = batch['bboxes'][idx]
|
||||
kpts = batch['keypoints'][idx]
|
||||
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
|
||||
nk = kpts.shape[1] # number of keypoints
|
||||
shape = batch['ori_shape'][si]
|
||||
correct_kpts = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
self.seen += 1
|
||||
|
||||
if npr == 0:
|
||||
if nl:
|
||||
self.stats.append((correct_bboxes, correct_kpts, *torch.zeros(
|
||||
(2, 0), device=self.device), cls.squeeze(-1)))
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
|
||||
continue
|
||||
|
||||
# Predictions
|
||||
if self.args.single_cls:
|
||||
pred[:, 5] = 0
|
||||
predn = pred.clone()
|
||||
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
|
||||
ratio_pad=batch['ratio_pad'][si]) # native-space pred
|
||||
pred_kpts = predn[:, 6:].view(npr, nk, -1)
|
||||
ops.scale_coords(batch['img'][si].shape[1:], pred_kpts, shape, ratio_pad=batch['ratio_pad'][si])
|
||||
|
||||
# Evaluate
|
||||
if nl:
|
||||
height, width = batch['img'].shape[2:]
|
||||
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
|
||||
(width, height, width, height), device=self.device) # target boxes
|
||||
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
|
||||
ratio_pad=batch['ratio_pad'][si]) # native-space labels
|
||||
tkpts = kpts.clone()
|
||||
tkpts[..., 0] *= width
|
||||
tkpts[..., 1] *= height
|
||||
tkpts = ops.scale_coords(batch['img'][si].shape[1:], tkpts, shape, ratio_pad=batch['ratio_pad'][si])
|
||||
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
||||
correct_bboxes = self._process_batch(predn[:, :6], labelsn)
|
||||
correct_kpts = self._process_batch(predn[:, :6], labelsn, pred_kpts, tkpts)
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(predn, labelsn)
|
||||
|
||||
# Append correct_masks, correct_boxes, pconf, pcls, tcls
|
||||
self.stats.append((correct_bboxes, correct_kpts, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
|
||||
|
||||
# Save
|
||||
if self.args.save_json:
|
||||
self.pred_to_json(predn, batch['im_file'][si])
|
||||
# if self.args.save_txt:
|
||||
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
|
||||
|
||||
def _process_batch(self, detections, labels, pred_kpts=None, gt_kpts=None):
|
||||
"""
|
||||
Return correct prediction matrix
|
||||
Arguments:
|
||||
detections (array[N, 6]), x1, y1, x2, y2, conf, class
|
||||
labels (array[M, 5]), class, x1, y1, x2, y2
|
||||
pred_kpts (array[N, 51]), 51 = 17 * 3
|
||||
gt_kpts (array[N, 51])
|
||||
Returns:
|
||||
correct (array[N, 10]), for 10 IoU levels
|
||||
"""
|
||||
if pred_kpts is not None and gt_kpts is not None:
|
||||
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
|
||||
area = ops.xyxy2xywh(labels[:, 1:])[:, 2:].prod(1) * 0.53
|
||||
iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
|
||||
else: # boxes
|
||||
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||
|
||||
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
|
||||
correct_class = labels[:, 0:1] == detections[:, 5]
|
||||
for i in range(len(self.iouv)):
|
||||
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
|
||||
if x[0].shape[0]:
|
||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
|
||||
1).cpu().numpy() # [label, detect, iou]
|
||||
if x[0].shape[0] > 1:
|
||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
||||
# matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
||||
correct[matches[:, 1].astype(int), i] = True
|
||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
|
||||
plot_images(batch['img'],
|
||||
batch['batch_idx'],
|
||||
batch['cls'].squeeze(-1),
|
||||
batch['bboxes'],
|
||||
kpts=batch['keypoints'],
|
||||
paths=batch['im_file'],
|
||||
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots predictions for YOLO model."""
|
||||
pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
|
||||
plot_images(batch['img'],
|
||||
*output_to_target(preds, max_det=self.args.max_det),
|
||||
kpts=pred_kpts,
|
||||
paths=batch['im_file'],
|
||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||
names=self.names,
|
||||
on_plot=self.on_plot) # pred
|
||||
|
||||
def pred_to_json(self, predn, filename):
|
||||
"""Converts YOLO predictions to COCO JSON format."""
|
||||
stem = Path(filename).stem
|
||||
image_id = int(stem) if stem.isnumeric() else stem
|
||||
box = ops.xyxy2xywh(predn[:, :4]) # xywh
|
||||
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
|
||||
for p, b in zip(predn.tolist(), box.tolist()):
|
||||
self.jdict.append({
|
||||
'image_id': image_id,
|
||||
'category_id': self.class_map[int(p[5])],
|
||||
'bbox': [round(x, 3) for x in b],
|
||||
'keypoints': p[6:],
|
||||
'score': round(p[4], 5)})
|
||||
|
||||
def eval_json(self, stats):
|
||||
"""Evaluates object detection model using COCO JSON format."""
|
||||
if self.args.save_json and self.is_coco and len(self.jdict):
|
||||
anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations
|
||||
pred_json = self.save_dir / 'predictions.json' # predictions
|
||||
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
|
||||
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
||||
check_requirements('pycocotools>=2.0.6')
|
||||
from pycocotools.coco import COCO # noqa
|
||||
from pycocotools.cocoeval import COCOeval # noqa
|
||||
|
||||
for x in anno_json, pred_json:
|
||||
assert x.is_file(), f'{x} file not found'
|
||||
anno = COCO(str(anno_json)) # init annotations api
|
||||
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'keypoints')]):
|
||||
if self.is_coco:
|
||||
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
|
||||
eval.evaluate()
|
||||
eval.accumulate()
|
||||
eval.summarize()
|
||||
idx = i * 4 + 2
|
||||
stats[self.metrics.keys[idx + 1]], stats[
|
||||
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'pycocotools unable to run: {e}')
|
||||
return stats
|
||||
|
||||
|
||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Performs validation on YOLO model using given data."""
|
||||
model = cfg.model or 'yolov8n-pose.pt'
|
||||
data = cfg.data or 'coco8-pose.yaml'
|
||||
|
||||
args = dict(model=model, data=data)
|
||||
if use_python:
|
||||
from ultralytics import YOLO
|
||||
YOLO(model).val(**args)
|
||||
else:
|
||||
validator = PoseValidator(args=args)
|
||||
validator(model=args['model'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
val()
|
@ -1,7 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .predict import SegmentationPredictor, predict
|
||||
from .train import SegmentationTrainer, train
|
||||
from .val import SegmentationValidator, val
|
||||
|
||||
__all__ = 'SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val'
|
@ -1,63 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
|
||||
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
||||
|
||||
|
||||
class SegmentationPredictor(DetectionPredictor):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = 'segment'
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""TODO: filter by classes."""
|
||||
p = ops.non_max_suppression(preds[0],
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
nc=len(self.model.names),
|
||||
classes=self.args.classes)
|
||||
results = []
|
||||
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
|
||||
for i, pred in enumerate(p):
|
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||
path = self.batch[0]
|
||||
img_path = path[i] if isinstance(path, list) else path
|
||||
if not len(pred): # save empty boxes
|
||||
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
|
||||
continue
|
||||
if self.args.retina_masks:
|
||||
if not isinstance(orig_imgs, torch.Tensor):
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
|
||||
else:
|
||||
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
|
||||
if not isinstance(orig_imgs, torch.Tensor):
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
results.append(
|
||||
Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
|
||||
return results
|
||||
|
||||
|
||||
def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Runs YOLO object detection on an image or video source."""
|
||||
model = cfg.model or 'yolov8n-seg.pt'
|
||||
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||
else 'https://ultralytics.com/images/bus.jpg'
|
||||
|
||||
args = dict(model=model, source=source)
|
||||
if use_python:
|
||||
from ultralytics import YOLO
|
||||
YOLO(model)(**args)
|
||||
else:
|
||||
predictor = SegmentationPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
predict()
|
@ -1,65 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
from copy import copy
|
||||
|
||||
from ultralytics.nn.tasks import SegmentationModel
|
||||
from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, RANK
|
||||
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
||||
|
||||
|
||||
# BaseTrainer python usage
|
||||
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""Initialize a SegmentationTrainer object with given arguments."""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides['task'] = 'segment'
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Return SegmentationModel initialized with specified config and weights."""
|
||||
model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
"""Return an instance of SegmentationValidator for validation of YOLO model."""
|
||||
self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
|
||||
return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Creates a plot of training sample images with labels and box coordinates."""
|
||||
plot_images(batch['img'],
|
||||
batch['batch_idx'],
|
||||
batch['cls'].squeeze(-1),
|
||||
batch['bboxes'],
|
||||
batch['masks'],
|
||||
paths=batch['im_file'],
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg',
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plots training/val metrics."""
|
||||
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
|
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Train a YOLO segmentation model based on passed arguments."""
|
||||
model = cfg.model or 'yolov8n-seg.pt'
|
||||
data = cfg.data or 'coco128-seg.yaml' # or yolo.ClassificationDataset("mnist")
|
||||
device = cfg.device if cfg.device is not None else ''
|
||||
|
||||
args = dict(model=model, data=data, device=device)
|
||||
if use_python:
|
||||
from ultralytics import YOLO
|
||||
YOLO(model).train(**args)
|
||||
else:
|
||||
trainer = SegmentationTrainer(overrides=args)
|
||||
trainer.train()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
@ -1,262 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, NUM_THREADS, ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.yolo.utils.metrics import SegmentMetrics, box_iou, mask_iou
|
||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
|
||||
from ultralytics.yolo.v8.detect import DetectionValidator
|
||||
|
||||
|
||||
class SegmentationValidator(DetectionValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
|
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.args.task = 'segment'
|
||||
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocesses batch by converting masks to float and sending to device."""
|
||||
batch = super().preprocess(batch)
|
||||
batch['masks'] = batch['masks'].to(self.device).float()
|
||||
return batch
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initialize metrics and select mask processing function based on save_json flag."""
|
||||
super().init_metrics(model)
|
||||
self.plot_masks = []
|
||||
if self.args.save_json:
|
||||
check_requirements('pycocotools>=2.0.6')
|
||||
self.process = ops.process_mask_upsample # more accurate
|
||||
else:
|
||||
self.process = ops.process_mask # faster
|
||||
|
||||
def get_desc(self):
|
||||
"""Return a formatted description of evaluation metrics."""
|
||||
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
|
||||
'R', 'mAP50', 'mAP50-95)')
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Postprocesses YOLO predictions and returns output detections with proto."""
|
||||
p = ops.non_max_suppression(preds[0],
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
labels=self.lb,
|
||||
multi_label=True,
|
||||
agnostic=self.args.single_cls,
|
||||
max_det=self.args.max_det,
|
||||
nc=self.nc)
|
||||
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
|
||||
return p, proto
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
"""Metrics."""
|
||||
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
|
||||
idx = batch['batch_idx'] == si
|
||||
cls = batch['cls'][idx]
|
||||
bbox = batch['bboxes'][idx]
|
||||
nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions
|
||||
shape = batch['ori_shape'][si]
|
||||
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
self.seen += 1
|
||||
|
||||
if npr == 0:
|
||||
if nl:
|
||||
self.stats.append((correct_bboxes, correct_masks, *torch.zeros(
|
||||
(2, 0), device=self.device), cls.squeeze(-1)))
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1))
|
||||
continue
|
||||
|
||||
# Masks
|
||||
midx = [si] if self.args.overlap_mask else idx
|
||||
gt_masks = batch['masks'][midx]
|
||||
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])
|
||||
|
||||
# Predictions
|
||||
if self.args.single_cls:
|
||||
pred[:, 5] = 0
|
||||
predn = pred.clone()
|
||||
ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape,
|
||||
ratio_pad=batch['ratio_pad'][si]) # native-space pred
|
||||
|
||||
# Evaluate
|
||||
if nl:
|
||||
height, width = batch['img'].shape[2:]
|
||||
tbox = ops.xywh2xyxy(bbox) * torch.tensor(
|
||||
(width, height, width, height), device=self.device) # target boxes
|
||||
ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape,
|
||||
ratio_pad=batch['ratio_pad'][si]) # native-space labels
|
||||
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
||||
correct_bboxes = self._process_batch(predn, labelsn)
|
||||
# TODO: maybe remove these `self.` arguments as they already are member variable
|
||||
correct_masks = self._process_batch(predn,
|
||||
labelsn,
|
||||
pred_masks,
|
||||
gt_masks,
|
||||
overlap=self.args.overlap_mask,
|
||||
masks=True)
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(predn, labelsn)
|
||||
|
||||
# Append correct_masks, correct_boxes, pconf, pcls, tcls
|
||||
self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
|
||||
|
||||
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
|
||||
if self.args.plots and self.batch_i < 3:
|
||||
self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
|
||||
|
||||
# Save
|
||||
if self.args.save_json:
|
||||
pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
|
||||
shape,
|
||||
ratio_pad=batch['ratio_pad'][si])
|
||||
self.pred_to_json(predn, batch['im_file'][si], pred_masks)
|
||||
# if self.args.save_txt:
|
||||
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
"""Sets speed and confusion matrix for evaluation metrics."""
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
|
||||
"""
|
||||
Return correct prediction matrix
|
||||
Arguments:
|
||||
detections (array[N, 6]), x1, y1, x2, y2, conf, class
|
||||
labels (array[M, 5]), class, x1, y1, x2, y2
|
||||
Returns:
|
||||
correct (array[N, 10]), for 10 IoU levels
|
||||
"""
|
||||
if masks:
|
||||
if overlap:
|
||||
nl = len(labels)
|
||||
index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
|
||||
gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
|
||||
gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
|
||||
if gt_masks.shape[1:] != pred_masks.shape[1:]:
|
||||
gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
|
||||
gt_masks = gt_masks.gt_(0.5)
|
||||
iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
|
||||
else: # boxes
|
||||
iou = box_iou(labels[:, 1:], detections[:, :4])
|
||||
|
||||
correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool)
|
||||
correct_class = labels[:, 0:1] == detections[:, 5]
|
||||
for i in range(len(self.iouv)):
|
||||
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
|
||||
if x[0].shape[0]:
|
||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
|
||||
1).cpu().numpy() # [label, detect, iou]
|
||||
if x[0].shape[0] > 1:
|
||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
||||
# matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
||||
correct[matches[:, 1].astype(int), i] = True
|
||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plots validation samples with bounding box labels."""
|
||||
plot_images(batch['img'],
|
||||
batch['batch_idx'],
|
||||
batch['cls'].squeeze(-1),
|
||||
batch['bboxes'],
|
||||
batch['masks'],
|
||||
paths=batch['im_file'],
|
||||
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots batch predictions with masks and bounding boxes."""
|
||||
plot_images(
|
||||
batch['img'],
|
||||
*output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
|
||||
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
|
||||
paths=batch['im_file'],
|
||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||
names=self.names,
|
||||
on_plot=self.on_plot) # pred
|
||||
self.plot_masks.clear()
|
||||
|
||||
def pred_to_json(self, predn, filename, pred_masks):
|
||||
"""Save one JSON result."""
|
||||
# Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
|
||||
from pycocotools.mask import encode # noqa
|
||||
|
||||
def single_encode(x):
|
||||
"""Encode predicted masks as RLE and append results to jdict."""
|
||||
rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
|
||||
rle['counts'] = rle['counts'].decode('utf-8')
|
||||
return rle
|
||||
|
||||
stem = Path(filename).stem
|
||||
image_id = int(stem) if stem.isnumeric() else stem
|
||||
box = ops.xyxy2xywh(predn[:, :4]) # xywh
|
||||
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
|
||||
pred_masks = np.transpose(pred_masks, (2, 0, 1))
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
rles = pool.map(single_encode, pred_masks)
|
||||
for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
|
||||
self.jdict.append({
|
||||
'image_id': image_id,
|
||||
'category_id': self.class_map[int(p[5])],
|
||||
'bbox': [round(x, 3) for x in b],
|
||||
'score': round(p[4], 5),
|
||||
'segmentation': rles[i]})
|
||||
|
||||
def eval_json(self, stats):
|
||||
"""Return COCO-style object detection evaluation metrics."""
|
||||
if self.args.save_json and self.is_coco and len(self.jdict):
|
||||
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
|
||||
pred_json = self.save_dir / 'predictions.json' # predictions
|
||||
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
|
||||
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
|
||||
check_requirements('pycocotools>=2.0.6')
|
||||
from pycocotools.coco import COCO # noqa
|
||||
from pycocotools.cocoeval import COCOeval # noqa
|
||||
|
||||
for x in anno_json, pred_json:
|
||||
assert x.is_file(), f'{x} file not found'
|
||||
anno = COCO(str(anno_json)) # init annotations api
|
||||
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
|
||||
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
|
||||
if self.is_coco:
|
||||
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval
|
||||
eval.evaluate()
|
||||
eval.accumulate()
|
||||
eval.summarize()
|
||||
idx = i * 4 + 2
|
||||
stats[self.metrics.keys[idx + 1]], stats[
|
||||
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'pycocotools unable to run: {e}')
|
||||
return stats
|
||||
|
||||
|
||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Validate trained YOLO model on validation data."""
|
||||
model = cfg.model or 'yolov8n-seg.pt'
|
||||
data = cfg.data or 'coco128-seg.yaml'
|
||||
|
||||
args = dict(model=model, data=data)
|
||||
if use_python:
|
||||
from ultralytics import YOLO
|
||||
YOLO(model).val(**args)
|
||||
else:
|
||||
validator = SegmentationValidator(args=args)
|
||||
validator(model=args['model'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
val()
|
Reference in New Issue
Block a user