Threadpool fixes and CLI improvements (#550)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -1,5 +1,4 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
import argparse
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
@ -9,46 +8,39 @@ from types import SimpleNamespace
|
||||
from typing import Dict, Union
|
||||
|
||||
from ultralytics import __version__, yolo
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, PREFIX, checks, colorstr, print_settings, yaml_load
|
||||
|
||||
DIR = Path(__file__).parent
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, USER_CONFIG_DIR,
|
||||
IterableSimpleNamespace, checks, colorstr, yaml_load, yaml_print)
|
||||
|
||||
CLI_HELP_MSG = \
|
||||
"""
|
||||
YOLOv8 CLI Usage examples:
|
||||
YOLOv8 'yolo' CLI commands use the following syntax:
|
||||
|
||||
1. Install the ultralytics package:
|
||||
yolo TASK MODE ARGS
|
||||
|
||||
pip install ultralytics
|
||||
Where TASK (optional) is one of [detect, segment, classify]
|
||||
MODE (required) is one of [train, val, predict, export]
|
||||
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||||
See all ARGS at https://docs.ultralytics.com/cfg or with 'yolo cfg'
|
||||
|
||||
2. Train, Val, Predict and Export using 'yolo' commands:
|
||||
1. Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
||||
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||
|
||||
yolo TASK MODE ARGS
|
||||
2. Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
|
||||
|
||||
Where TASK (optional) is one of [detect, segment, classify]
|
||||
MODE (required) is one of [train, val, predict, export]
|
||||
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||||
For a full list of available ARGS see https://docs.ultralytics.com/cfg.
|
||||
3. Val a pretrained detection model at batch-size 1 and image size 640:
|
||||
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||
|
||||
Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
||||
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||
|
||||
Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
|
||||
|
||||
Validate a pretrained detection model at batch-size 1 and image size 640:
|
||||
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||
|
||||
Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||||
|
||||
3. Run special commands:
|
||||
4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||||
|
||||
5. Run special commands:
|
||||
yolo help
|
||||
yolo checks
|
||||
yolo version
|
||||
yolo settings
|
||||
yolo copy-cfg
|
||||
yolo cfg
|
||||
|
||||
Docs: https://docs.ultralytics.com/cli
|
||||
Community: https://community.ultralytics.com
|
||||
@ -56,15 +48,6 @@ CLI_HELP_MSG = \
|
||||
"""
|
||||
|
||||
|
||||
class UltralyticsCFG(SimpleNamespace):
|
||||
"""
|
||||
UltralyticsCFG iterable SimpleNamespace class to allow SimpleNamespace to be used with dict() and in for loops
|
||||
"""
|
||||
|
||||
def __iter__(self):
|
||||
return iter(vars(self).items())
|
||||
|
||||
|
||||
def cfg2dict(cfg):
|
||||
"""
|
||||
Convert a configuration object to a dictionary.
|
||||
@ -104,7 +87,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace], overrides: Dict = None
|
||||
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
|
||||
|
||||
# Return instance
|
||||
return UltralyticsCFG(**cfg)
|
||||
return IterableSimpleNamespace(**cfg)
|
||||
|
||||
|
||||
def check_cfg_mismatch(base: Dict, custom: Dict):
|
||||
@ -118,12 +101,19 @@ def check_cfg_mismatch(base: Dict, custom: Dict):
|
||||
"""
|
||||
base, custom = (set(x.keys()) for x in (base, custom))
|
||||
mismatched = [x for x in custom if x not in base]
|
||||
for option in mismatched:
|
||||
LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, base, 3, 0.6)}")
|
||||
if mismatched:
|
||||
for x in mismatched:
|
||||
matches = get_close_matches(x, base, 3, 0.6)
|
||||
match_str = f"Similar arguments are {matches}." if matches else 'There are no similar arguments.'
|
||||
LOGGER.warning(f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}")
|
||||
LOGGER.warning(CLI_HELP_MSG)
|
||||
sys.exit()
|
||||
|
||||
|
||||
def argument_error(arg):
|
||||
return SyntaxError(f"'{arg}' is not a valid YOLO argument.\n{CLI_HELP_MSG}")
|
||||
|
||||
|
||||
def entrypoint(debug=False):
|
||||
"""
|
||||
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
||||
@ -139,67 +129,61 @@ def entrypoint(debug=False):
|
||||
It uses the package's default cfg and initializes it using the passed overrides.
|
||||
Then it calls the CLI function with the composed cfg
|
||||
"""
|
||||
if debug:
|
||||
args = ['train', 'predict', 'model=yolov8n.pt'] # for testing
|
||||
else:
|
||||
if len(sys.argv) == 1: # no arguments passed
|
||||
LOGGER.info(CLI_HELP_MSG)
|
||||
return
|
||||
|
||||
parser = argparse.ArgumentParser(description='YOLO parser')
|
||||
parser.add_argument('args', type=str, nargs='+', help='YOLO args')
|
||||
args = parser.parse_args().args
|
||||
args = re.sub(r'\s*=\s*', '=', ' '.join(args)).split(' ') # remove whitespaces around = sign
|
||||
args = ['train', 'predict', 'model=yolov8n.pt'] if debug else sys.argv[1:]
|
||||
if not args: # no arguments passed
|
||||
LOGGER.info(CLI_HELP_MSG)
|
||||
return
|
||||
|
||||
tasks = 'detect', 'segment', 'classify'
|
||||
modes = 'train', 'val', 'predict', 'export'
|
||||
special_modes = {
|
||||
special = {
|
||||
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
||||
'checks': checks.check_yolo,
|
||||
'version': lambda: LOGGER.info(__version__),
|
||||
'settings': print_settings,
|
||||
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
|
||||
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||
'copy-cfg': copy_default_config}
|
||||
|
||||
overrides = {} # basic overrides, i.e. imgsz=320
|
||||
defaults = yaml_load(DEFAULT_CFG_PATH)
|
||||
for a in args:
|
||||
if '=' in a:
|
||||
if a.startswith('cfg='): # custom.yaml passed
|
||||
custom_config = Path(a.split('=')[-1])
|
||||
LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {custom_config}")
|
||||
overrides = {k: v for k, v in yaml_load(custom_config).items() if k not in {'cfg'}}
|
||||
else:
|
||||
try:
|
||||
re.sub(r' *= *', '=', a) # remove spaces around equals sign
|
||||
k, v = a.split('=')
|
||||
try:
|
||||
if k == 'device': # special DDP handling, i.e. device='0,1,2,3'
|
||||
v = v.replace('[', '').replace(']', '') # handle device=[0,1,2,3]
|
||||
v = v.replace(" ", "") # handle device=[0, 1, 2, 3]
|
||||
v = v.replace('\\', '') # handle device=\'0,1,2,3\'
|
||||
overrides[k] = v
|
||||
else:
|
||||
overrides[k] = eval(v) # convert strings to integers, floats, bools, etc.
|
||||
except (NameError, SyntaxError):
|
||||
if k == 'cfg': # custom.yaml passed
|
||||
LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {v}")
|
||||
overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
|
||||
else:
|
||||
if v.isnumeric():
|
||||
v = eval(v)
|
||||
elif v.lower() == 'none':
|
||||
v = None
|
||||
elif v.lower() == 'true':
|
||||
v = True
|
||||
elif v.lower() == 'false':
|
||||
v = False
|
||||
elif ',' in v:
|
||||
v = eval(v)
|
||||
overrides[k] = v
|
||||
except (NameError, SyntaxError, ValueError) as e:
|
||||
raise argument_error(a) from e
|
||||
|
||||
elif a in tasks:
|
||||
overrides['task'] = a
|
||||
elif a in modes:
|
||||
overrides['mode'] = a
|
||||
elif a in special_modes:
|
||||
special_modes[a]()
|
||||
elif a in special:
|
||||
special[a]()
|
||||
return
|
||||
elif a in defaults and defaults[a] is False:
|
||||
elif a in DEFAULT_CFG_DICT and DEFAULT_CFG_DICT[a] is False:
|
||||
overrides[a] = True # auto-True for default False args, i.e. 'yolo show' sets show=True
|
||||
elif a in defaults:
|
||||
raise SyntaxError(f"'{a}' is a valid YOLO argument but is missing an '=' sign to set its value, "
|
||||
f"i.e. try '{a}={defaults[a]}'"
|
||||
f"\n{CLI_HELP_MSG}")
|
||||
elif a in DEFAULT_CFG_DICT:
|
||||
raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
|
||||
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
|
||||
else:
|
||||
raise SyntaxError(
|
||||
f"'{a}' is not a valid YOLO argument. For a full list of valid arguments see "
|
||||
f"https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/configs/default.yaml"
|
||||
f"\n{CLI_HELP_MSG}")
|
||||
raise argument_error(a)
|
||||
|
||||
cfg = get_cfg(defaults, overrides) # create CFG instance
|
||||
cfg = get_cfg(DEFAULT_CFG_DICT, overrides) # create CFG instance
|
||||
|
||||
# Mapping from task to module
|
||||
module = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}.get(cfg.task)
|
||||
@ -223,8 +207,8 @@ def copy_default_config():
|
||||
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
||||
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
||||
LOGGER.info(f"{PREFIX}{DEFAULT_CFG_PATH} copied to {new_file}\n"
|
||||
f"Usage for running YOLO with this new custom cfg:\nyolo cfg={new_file} args...")
|
||||
f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
entrypoint()
|
||||
entrypoint(debug=True)
|
||||
|
@ -93,7 +93,7 @@ class BaseDataset(Dataset):
|
||||
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
||||
assert im_files, f"{self.prefix}No images found"
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}: {e}\n{HELP_URL}") from e
|
||||
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
|
||||
return im_files
|
||||
|
||||
def update_labels(self, include_class: Optional[list]):
|
||||
@ -134,16 +134,17 @@ class BaseDataset(Dataset):
|
||||
gb = 0 # Gigabytes of cached images
|
||||
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
|
||||
fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
|
||||
results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni))
|
||||
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if cache == "disk":
|
||||
gb += self.npy_files[i].stat().st_size
|
||||
else: # 'ram'
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||
gb += self.ims[i].nbytes
|
||||
pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
|
||||
pbar.close()
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(fcn, range(self.ni))
|
||||
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if cache == "disk":
|
||||
gb += self.npy_files[i].stat().st_size
|
||||
else: # 'ram'
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||
gb += self.ims[i].nbytes
|
||||
pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
|
||||
pbar.close()
|
||||
|
||||
def cache_images_to_disk(self, i):
|
||||
# Saves an image as an *.npy file for faster loading
|
||||
|
@ -13,7 +13,7 @@ import random
|
||||
import shutil
|
||||
import time
|
||||
from itertools import repeat
|
||||
from multiprocessing.pool import Pool, ThreadPool
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from urllib.parse import urlparse
|
||||
@ -580,7 +580,7 @@ class LoadImagesAndLabels(Dataset):
|
||||
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
||||
self.im_hw0, self.im_hw = [None] * n, [None] * n
|
||||
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
|
||||
with (Pool if n > 10000 else ThreadPool)(NUM_THREADS) as pool:
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(fcn, range(n))
|
||||
pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
@ -1150,7 +1150,7 @@ class HUBDatasetStats():
|
||||
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
|
||||
desc = f'{split} images'
|
||||
total = dataset.n
|
||||
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc):
|
||||
pass
|
||||
print(f'Done. All images saved to {self.im_dir}')
|
||||
|
@ -185,9 +185,9 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
|
||||
return masks, index
|
||||
|
||||
|
||||
def check_dataset_yaml(data, autodownload=True):
|
||||
def check_dataset_yaml(dataset, autodownload=True):
|
||||
# Download, check and/or unzip dataset if not found locally
|
||||
data = check_file(data)
|
||||
data = check_file(dataset)
|
||||
|
||||
# Download (optional)
|
||||
extract_dir = ''
|
||||
@ -227,9 +227,11 @@ def check_dataset_yaml(data, autodownload=True):
|
||||
if val:
|
||||
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
|
||||
if not all(x.exists() for x in val):
|
||||
LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
|
||||
if not s or not autodownload:
|
||||
raise FileNotFoundError('Dataset not found ❌')
|
||||
msg = f"\nDataset '{dataset}' not found ⚠️, missing paths %s" % [str(x) for x in val if not x.exists()]
|
||||
if s and autodownload:
|
||||
LOGGER.warning(msg)
|
||||
else:
|
||||
raise FileNotFoundError(s)
|
||||
t = time.time()
|
||||
if s.startswith('http') and s.endswith('.zip'): # URL
|
||||
f = Path(s).name # filename
|
||||
|
@ -126,15 +126,15 @@ class Exporter:
|
||||
save_dir (Path): Directory to save results.
|
||||
"""
|
||||
|
||||
def __init__(self, config=DEFAULT_CFG, overrides=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||
"""
|
||||
Initializes the Exporter class.
|
||||
|
||||
Args:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(config, overrides)
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
@ -151,7 +151,7 @@ class Exporter:
|
||||
# Load PyTorch model
|
||||
self.device = select_device('cpu' if self.args.device is None else self.args.device)
|
||||
if self.args.half:
|
||||
if self.device.type == 'cpu' and not coreml:
|
||||
if self.device.type == 'cpu' and not coreml and not xml:
|
||||
LOGGER.info('half=True only compatible with GPU or CoreML export, i.e. use device=0 or format=coreml')
|
||||
self.args.half = False
|
||||
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
|
||||
@ -184,7 +184,7 @@ class Exporter:
|
||||
y = None
|
||||
for _ in range(2):
|
||||
y = model(im) # dry runs
|
||||
if self.args.half and not coreml:
|
||||
if self.args.half and not coreml and not xml:
|
||||
im, model = im.half(), model.half() # to FP16
|
||||
shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
|
||||
LOGGER.info(
|
||||
@ -332,7 +332,7 @@ class Exporter:
|
||||
f = str(self.file).replace(self.file.suffix, f'_openvino_model{os.sep}')
|
||||
f_onnx = self.file.with_suffix('.onnx')
|
||||
|
||||
cmd = f"mo --input_model {f_onnx} --output_dir {f} --data_type {'FP16' if self.args.half else 'FP32'}"
|
||||
cmd = f"mo --input_model {f_onnx} --output_dir {f} {'--compress_to_fp16' * self.args.half}"
|
||||
subprocess.run(cmd.split(), check=True, env=os.environ) # export
|
||||
yaml_save(Path(f) / self.file.with_suffix('.yaml').name, self.metadata) # add metadata.yaml
|
||||
return f, None
|
||||
|
@ -6,7 +6,7 @@ from ultralytics import yolo # noqa
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, yaml_load
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
|
||||
|
||||
@ -151,7 +151,7 @@ class YOLO:
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
overrides["mode"] = "val"
|
||||
args = get_cfg(cfg=DEFAULT_CFG_PATH, overrides=overrides)
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.data = data or args.data
|
||||
args.task = self.task
|
||||
|
||||
@ -169,7 +169,7 @@ class YOLO:
|
||||
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
args = get_cfg(cfg=DEFAULT_CFG_PATH, overrides=overrides)
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.task = self.task
|
||||
|
||||
print(args)
|
||||
@ -181,8 +181,7 @@ class YOLO:
|
||||
Trains the model on a given dataset.
|
||||
|
||||
Args:
|
||||
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
|
||||
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
|
||||
**kwargs (Any): Any number of arguments representing the training configuration.
|
||||
"""
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
@ -192,7 +191,7 @@ class YOLO:
|
||||
overrides["task"] = self.task
|
||||
overrides["mode"] = "train"
|
||||
if not overrides.get("data"):
|
||||
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
|
||||
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
|
||||
if overrides.get("resume"):
|
||||
overrides["resume"] = self.ckpt_path
|
||||
|
||||
@ -223,6 +222,13 @@ class YOLO:
|
||||
|
||||
return model_class, trainer_class, validator_class, predictor_class
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
"""
|
||||
Returns class names of the loaded model.
|
||||
"""
|
||||
return self.model.names
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
args.pop("project", None)
|
||||
|
@ -27,7 +27,6 @@ Usage - formats:
|
||||
"""
|
||||
import platform
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
@ -62,15 +61,15 @@ class BasePredictor:
|
||||
data_path (str): Path to data.
|
||||
"""
|
||||
|
||||
def __init__(self, config=DEFAULT_CFG_PATH, overrides=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None):
|
||||
"""
|
||||
Initializes the BasePredictor class.
|
||||
|
||||
Args:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(config, overrides)
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||
name = self.args.name or f"{self.args.mode}"
|
||||
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
||||
@ -219,7 +218,7 @@ class BasePredictor:
|
||||
self.run_callbacks("on_predict_batch_end")
|
||||
|
||||
# Print results
|
||||
if verbose:
|
||||
if verbose and self.seen:
|
||||
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
|
||||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape '
|
||||
f'{(1, 3, *self.imgsz)}' % t)
|
||||
|
@ -31,7 +31,8 @@ from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
|
||||
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
||||
from ultralytics.yolo.utils.files import get_latest_run, increment_path
|
||||
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
|
||||
from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle,
|
||||
strip_optimizer)
|
||||
|
||||
|
||||
class BaseTrainer:
|
||||
@ -71,15 +72,15 @@ class BaseTrainer:
|
||||
csv (Path): Path to results CSV file.
|
||||
"""
|
||||
|
||||
def __init__(self, config=DEFAULT_CFG_PATH, overrides=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None):
|
||||
"""
|
||||
Initializes the BaseTrainer class.
|
||||
|
||||
Args:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(config, overrides)
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch)
|
||||
self.check_resume()
|
||||
self.console = LOGGER
|
||||
@ -225,6 +226,7 @@ class BaseTrainer:
|
||||
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
||||
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
||||
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
|
||||
|
||||
# dataloaders
|
||||
batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
|
||||
@ -333,10 +335,12 @@ class BaseTrainer:
|
||||
|
||||
# Validation
|
||||
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
||||
final_epoch = (epoch + 1 == self.epochs)
|
||||
final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop
|
||||
|
||||
if self.args.val or final_epoch:
|
||||
self.metrics, self.fitness = self.validate()
|
||||
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
|
||||
self.stop = self.stopper(epoch + 1, self.fitness)
|
||||
|
||||
# Save model
|
||||
if self.args.save or (epoch + 1 == self.epochs):
|
||||
@ -347,7 +351,15 @@ class BaseTrainer:
|
||||
self.epoch_time = tnow - self.epoch_time_start
|
||||
self.epoch_time_start = tnow
|
||||
self.run_callbacks("on_fit_epoch_end")
|
||||
# TODO: termination condition
|
||||
|
||||
# Early Stopping
|
||||
if RANK != -1: # if DDP training
|
||||
broadcast_list = [self.stop if RANK == 0 else None]
|
||||
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
|
||||
if RANK != 0:
|
||||
self.stop = broadcast_list[0]
|
||||
if self.stop:
|
||||
break # must break all DDP ranks
|
||||
|
||||
if rank in {-1, 0}:
|
||||
# Do final val with best.pt
|
||||
|
@ -8,9 +8,9 @@ import platform
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import types
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Union
|
||||
|
||||
import cv2
|
||||
@ -55,10 +55,34 @@ HELP_MSG = \
|
||||
|
||||
3. Use the command line interface (CLI):
|
||||
|
||||
yolo task=detect mode=train model=yolov8n.yaml args...
|
||||
classify predict yolov8n-cls.yaml args...
|
||||
segment val yolov8n-seg.yaml args...
|
||||
export yolov8n.pt format=onnx args...
|
||||
YOLOv8 'yolo' CLI commands use the following syntax:
|
||||
|
||||
yolo TASK MODE ARGS
|
||||
|
||||
Where TASK (optional) is one of [detect, segment, classify]
|
||||
MODE (required) is one of [train, val, predict, export]
|
||||
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||||
See all ARGS at https://docs.ultralytics.com/cfg or with 'yolo cfg'
|
||||
|
||||
- Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
||||
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||
|
||||
- Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
|
||||
|
||||
- Val a pretrained detection model at batch-size 1 and image size 640:
|
||||
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||
|
||||
- Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||||
|
||||
- Run special commands:
|
||||
yolo help
|
||||
yolo checks
|
||||
yolo version
|
||||
yolo settings
|
||||
yolo copy-cfg
|
||||
yolo cfg
|
||||
|
||||
Docs: https://docs.ultralytics.com
|
||||
Community: https://community.ultralytics.com
|
||||
@ -73,11 +97,24 @@ cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with Py
|
||||
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
|
||||
|
||||
# Default config dictionary
|
||||
|
||||
class IterableSimpleNamespace(SimpleNamespace):
|
||||
"""
|
||||
Iterable SimpleNamespace class to allow SimpleNamespace to be used with dict() and in for loops
|
||||
"""
|
||||
|
||||
def __iter__(self):
|
||||
return iter(vars(self).items())
|
||||
|
||||
def __str__(self):
|
||||
return '\n'.join(f"{k}={v}" for k, v in vars(self).items())
|
||||
|
||||
|
||||
# Default configuration
|
||||
with open(DEFAULT_CFG_PATH, errors='ignore') as f:
|
||||
DEFAULT_CFG_DICT = yaml.safe_load(f)
|
||||
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
|
||||
DEFAULT_CFG = types.SimpleNamespace(**DEFAULT_CFG_DICT)
|
||||
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
|
||||
|
||||
|
||||
def is_colab():
|
||||
@ -307,14 +344,15 @@ def set_logging(name=LOGGING_NAME, verbose=True):
|
||||
|
||||
class TryExcept(contextlib.ContextDecorator):
|
||||
# YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
|
||||
def __init__(self, msg=''):
|
||||
def __init__(self, msg='', verbose=True):
|
||||
self.msg = msg
|
||||
self.verbose = verbose
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, value, traceback):
|
||||
if value:
|
||||
if self.verbose and value:
|
||||
print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
|
||||
return True
|
||||
|
||||
@ -366,6 +404,21 @@ def yaml_load(file='data.yaml', append_filename=False):
|
||||
return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
|
||||
|
||||
|
||||
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
|
||||
"""
|
||||
Pretty prints a yaml file or a yaml-formatted dictionary.
|
||||
|
||||
Args:
|
||||
yaml_file: The file path of the yaml file or a yaml-formatted dictionary.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
|
||||
dump = yaml.dump(yaml_dict, default_flow_style=False)
|
||||
LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
|
||||
|
||||
|
||||
def set_sentry(dsn=None):
|
||||
"""
|
||||
Initialize the Sentry SDK for error tracking and reporting if pytest is not currently running.
|
||||
@ -379,7 +432,6 @@ def set_sentry(dsn=None):
|
||||
debug=False,
|
||||
traces_sample_rate=1.0,
|
||||
release=ultralytics.__version__,
|
||||
send_default_pii=True,
|
||||
environment='production', # 'dev' or 'production'
|
||||
ignore_errors=[KeyboardInterrupt])
|
||||
|
||||
@ -439,17 +491,6 @@ def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
|
||||
yaml_save(file, SETTINGS)
|
||||
|
||||
|
||||
def print_settings():
|
||||
"""
|
||||
Function that prints Ultralytics settings
|
||||
"""
|
||||
import json
|
||||
s = f'\n{PREFIX}Settings:\n'
|
||||
s += json.dumps(SETTINGS, indent=2)
|
||||
s += f"\n\nUpdate settings at {USER_CONFIG_DIR / 'settings.yaml'}"
|
||||
LOGGER.info(s)
|
||||
|
||||
|
||||
# Run below code on utils init -----------------------------------------------------------------------------------------
|
||||
|
||||
# Set logger
|
||||
|
@ -3,7 +3,7 @@
|
||||
import json
|
||||
from time import time
|
||||
|
||||
from ultralytics.hub.utils import PREFIX, sync_analytics
|
||||
from ultralytics.hub.utils import PREFIX, traces
|
||||
from ultralytics.yolo.utils import LOGGER
|
||||
|
||||
|
||||
@ -43,24 +43,24 @@ def on_train_end(trainer):
|
||||
LOGGER.info(f"{PREFIX}Training completed successfully ✅\n"
|
||||
f"{PREFIX}Uploading final {session.model_id}")
|
||||
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics['metrics/mAP50-95(B)'], final=True)
|
||||
session.alive = False # stop heartbeats
|
||||
session.shutdown() # stop heartbeats
|
||||
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀")
|
||||
|
||||
|
||||
def on_train_start(trainer):
|
||||
sync_analytics(trainer.args)
|
||||
traces(trainer.args, traces_sample_rate=0.0)
|
||||
|
||||
|
||||
def on_val_start(validator):
|
||||
sync_analytics(validator.args)
|
||||
traces(validator.args, traces_sample_rate=0.0)
|
||||
|
||||
|
||||
def on_predict_start(predictor):
|
||||
sync_analytics(predictor.args)
|
||||
traces(predictor.args, traces_sample_rate=0.0)
|
||||
|
||||
|
||||
def on_export_start(exporter):
|
||||
sync_analytics(exporter.args)
|
||||
traces(exporter.args, traces_sample_rate=0.0)
|
||||
|
||||
|
||||
callbacks = {
|
||||
|
@ -154,7 +154,7 @@ def check_python(minimum: str = '3.7.0') -> bool:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
check_version(platform.python_version(), minimum, name='Python ', hard=True)
|
||||
return check_version(platform.python_version(), minimum, name='Python ', hard=True)
|
||||
|
||||
|
||||
@TryExcept()
|
||||
@ -223,8 +223,10 @@ def check_file(file, suffix=''):
|
||||
files = []
|
||||
for d in 'models', 'yolo/data': # search directories
|
||||
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
|
||||
assert len(files), f'File not found: {file}' # assert file was found
|
||||
assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
|
||||
if not files:
|
||||
raise FileNotFoundError(f"{file} does not exist")
|
||||
elif len(files) > 1:
|
||||
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
|
||||
return files[0] # return file
|
||||
|
||||
|
||||
|
@ -141,10 +141,14 @@ def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1
|
||||
dir = Path(dir)
|
||||
dir.mkdir(parents=True, exist_ok=True) # make directory
|
||||
if threads > 1:
|
||||
pool = ThreadPool(threads)
|
||||
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
|
||||
pool.close()
|
||||
pool.join()
|
||||
# pool = ThreadPool(threads)
|
||||
# pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
|
||||
# pool.close()
|
||||
# pool.join()
|
||||
with ThreadPool(threads) as pool:
|
||||
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
|
||||
pool.close()
|
||||
pool.join()
|
||||
else:
|
||||
for u in [url] if isinstance(url, (str, Path)) else url:
|
||||
download_one(u, dir)
|
||||
|
@ -62,7 +62,9 @@ def select_device(device='', batch_size=0, newline=False):
|
||||
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
|
||||
ver = git_describe() or ultralytics.__version__ # git commit or pip package version
|
||||
s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
|
||||
device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
|
||||
device = str(device).lower()
|
||||
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
|
||||
device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
||||
cpu = device == 'cpu'
|
||||
mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
|
||||
if cpu or mps:
|
||||
@ -369,3 +371,26 @@ def profile(input, ops, n=10, device=None):
|
||||
results.append(None)
|
||||
torch.cuda.empty_cache()
|
||||
return results
|
||||
|
||||
|
||||
class EarlyStopping:
|
||||
# early stopper
|
||||
def __init__(self, patience=30):
|
||||
self.best_fitness = 0.0 # i.e. mAP
|
||||
self.best_epoch = 0
|
||||
self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
|
||||
self.possible_stop = False # possible stop may occur next epoch
|
||||
|
||||
def __call__(self, epoch, fitness):
|
||||
if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
|
||||
self.best_epoch = epoch
|
||||
self.best_fitness = fitness
|
||||
delta = epoch - self.best_epoch # epochs without improvement
|
||||
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
|
||||
stop = delta >= self.patience # stop training if patience exceeded
|
||||
if stop:
|
||||
LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
|
||||
f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
|
||||
f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
|
||||
f'i.e. `patience=300` or use `patience=0` to disable EarlyStopping.')
|
||||
return stop
|
||||
|
@ -4,7 +4,7 @@ import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT
|
||||
from ultralytics.yolo.utils.plotting import Annotator
|
||||
|
||||
|
||||
@ -65,7 +65,7 @@ class ClassificationPredictor(BasePredictor):
|
||||
|
||||
def predict(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
|
||||
else "https://ultralytics.com/images/bus.jpg"
|
||||
predictor = ClassificationPredictor(cfg)
|
||||
predictor.predict_cli()
|
||||
|
@ -4,7 +4,7 @@ import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory, ops
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
|
||||
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
||||
|
||||
|
||||
@ -83,7 +83,7 @@ class DetectionPredictor(BasePredictor):
|
||||
|
||||
def predict(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n.pt"
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
|
||||
else "https://ultralytics.com/images/bus.jpg"
|
||||
predictor = DetectionPredictor(cfg)
|
||||
predictor.predict_cli()
|
||||
|
@ -3,7 +3,7 @@
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory, ops
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
|
||||
from ultralytics.yolo.utils.plotting import colors, save_one_box
|
||||
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
||||
|
||||
@ -100,7 +100,7 @@ class SegmentationPredictor(DetectionPredictor):
|
||||
|
||||
def predict(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n-seg.pt"
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
|
||||
else "https://ultralytics.com/images/bus.jpg"
|
||||
predictor = SegmentationPredictor(cfg)
|
||||
predictor.predict_cli()
|
||||
|
Reference in New Issue
Block a user