`ultralytics 8.0.12` - Hydra removal (#506)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Pronoy Mandal <lukex9442@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Glenn Jocher 2 years ago committed by GitHub
parent 6eec39162a
commit c5fccc3fc4

@@ -85,6 +85,7 @@ jobs:
shell: bash # for Windows compatibility
run: |
yolo task=detect mode=train data=coco8.yaml model=yolov8n.yaml epochs=1 imgsz=32
+yolo task=detect mode=train data=coco8.yaml model=yolov8n.pt epochs=1 imgsz=32
yolo task=detect mode=val data=coco8.yaml model=runs/detect/train/weights/last.pt imgsz=32
yolo task=detect mode=predict model=runs/detect/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
yolo mode=export model=runs/detect/train/weights/last.pt imgsz=32 format=torchscript
@@ -92,6 +93,7 @@ jobs:
shell: bash # for Windows compatibility
run: |
yolo task=segment mode=train data=coco8-seg.yaml model=yolov8n-seg.yaml epochs=1 imgsz=32
+yolo task=segment mode=train data=coco8-seg.yaml model=yolov8n-seg.pt epochs=1 imgsz=32
yolo task=segment mode=val data=coco8-seg.yaml model=runs/segment/train/weights/last.pt imgsz=32
yolo task=segment mode=predict model=runs/segment/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
yolo mode=export model=runs/segment/train/weights/last.pt imgsz=32 format=torchscript
@@ -99,6 +101,7 @@ jobs:
shell: bash # for Windows compatibility
run: |
yolo task=classify mode=train data=mnist160 model=yolov8n-cls.yaml epochs=1 imgsz=32
+yolo task=classify mode=train data=mnist160 model=yolov8n-cls.pt epochs=1 imgsz=32
yolo task=classify mode=val data=mnist160 model=runs/classify/train/weights/last.pt imgsz=32
yolo task=classify mode=predict model=runs/classify/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript
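The added CI lines smoke-test training from pretrained `*.pt` weights in addition to building models from `*.yaml` configs. A minimal Python equivalent of that check through the package API (a sketch; it uses the same `coco8.yaml` and `yolov8n.pt` assets the workflow exercises):

```python
from ultralytics import YOLO

# Train for one epoch from pretrained weights, mirroring the added CI line
model = YOLO("yolov8n.pt")                          # pretrained detection weights
model.train(data="coco8.yaml", epochs=1, imgsz=32)  # tiny smoke-test run
```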

.gitignore

@@ -136,6 +136,7 @@ wandb/
.DS_Store
# Neural Network weights -----------------------------------------------------------------------------------------------
+weights/
*.weights
*.pt
*.pb

@@ -10,7 +10,7 @@ ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Aria
# Remove torch nightly and install torch stable
RUN rm -rf /opt/pytorch # remove 1.2GB dir
-RUN pip uninstall -y torchtext torch torchvision
+RUN pip uninstall -y torchtext pillow torch torchvision
RUN pip install --no-cache torch torchvision
# Install linux packages

@@ -6,9 +6,9 @@ Inference or prediction of a task returns a list of `Results` objects. Alternati
inputs = [img, img] # list of np arrays
results = model(inputs) # List of Results objects
for result in results:
-boxes = results.boxes # Boxes object for bbox outputs
-masks = results.masks # Masks object for segmenation masks outputs
-probs = results.probs # Class probabilities for classification outputs
+boxes = result.boxes # Boxes object for bbox outputs
+masks = result.masks # Masks object for segmenation masks outputs
+probs = result.probs # Class probabilities for classification outputs
...
```
=== "Getting a Generator"
@@ -16,9 +16,9 @@ Inference or prediction of a task returns a list of `Results` objects. Alternati
inputs = [img, img] # list of np arrays
results = model(inputs, stream=True) # Generator of Results objects
for result in results:
-boxes = results.boxes # Boxes object for bbox outputs
-masks = results.masks # Masks object for segmenation masks outputs
-probs = results.probs # Class probabilities for classification outputs
+boxes = result.boxes # Boxes object for bbox outputs
+masks = result.masks # Masks object for segmenation masks outputs
+probs = result.probs # Class probabilities for classification outputs
...
```
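The docs fix above corrects the loop variable: the attributes live on each `result`, not on the `results` container. A runnable sketch of the corrected usage (the image path is a placeholder):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
results = model(["bus.jpg", "bus.jpg"])  # list of Results objects
for result in results:                   # iterate the individual Results
    boxes = result.boxes  # Boxes object for bbox outputs
    masks = result.masks  # None for a plain detection model
    probs = result.probs  # None unless running a classification model
```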

@@ -2,7 +2,6 @@
# Usage: pip install -r requirements.txt
# Base ----------------------------------------
-hydra-core>=1.2.0
matplotlib>=3.2.2
numpy>=1.18.5
opencv-python>=4.1.1

@@ -51,4 +51,5 @@ setup(
"Operating System :: MacOS", "Operating System :: Microsoft :: Windows"],
keywords="machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics",
entry_points={
-'console_scripts': ['yolo = ultralytics.yolo.cli:entrypoint', 'ultralytics = ultralytics.yolo.cli:entrypoint']})
+'console_scripts':
+['yolo = ultralytics.yolo.configs:entrypoint', 'ultralytics = ultralytics.yolo.configs:entrypoint']})

@@ -3,13 +3,13 @@
from pathlib import Path
from ultralytics.yolo.configs import get_config
-from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, SETTINGS
+from ultralytics.yolo.utils import DEFAULT_CFG_PATH, ROOT, SETTINGS
from ultralytics.yolo.v8 import classify, detect, segment
CFG_DET = 'yolov8n.yaml'
CFG_SEG = 'yolov8n-seg.yaml'
CFG_CLS = 'squeezenet1_0'
-CFG = get_config(DEFAULT_CONFIG)
+CFG = get_config(DEFAULT_CFG_PATH)
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
SOURCE = ROOT / "assets"
@@ -49,6 +49,8 @@ def test_predict_img():
assert len(output) == 1, "predict test failed"
output = model(source=[img, img], save=True, save_txt=True) # batch
assert len(output) == 2, "predict test failed"
+output = model(source=[img, img], save=True, stream=True) # stream
+assert len(list(output)) == 2, "predict test failed"
tens = torch.zeros(320, 640, 3)
output = model(tens.numpy())
assert len(output) == 1, "predict test failed"
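The added test exercises the new `stream=True` path, which returns a lazy generator rather than a list. A self-contained sketch of the same check (the dummy array stands in for a real image):

```python
import numpy as np
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
img = np.zeros((320, 640, 3), dtype=np.uint8)    # dummy image
outputs = model(source=[img, img], stream=True)  # generator of Results, nothing run yet
assert len(list(outputs)) == 2                   # consuming it yields one Results per input
```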

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
-__version__ = "8.0.11"
+__version__ = "8.0.12"
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import ops

@@ -7,7 +7,7 @@ import time
import requests
-from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, LOGGER, RANK, SETTINGS, TryExcept, colorstr, emojis
+from ultralytics.yolo.utils import DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, TryExcept, colorstr, emojis
PREFIX = colorstr('Ultralytics: ')
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
@@ -143,7 +143,7 @@ def sync_analytics(cfg, all_keys=False, enabled=False):
if SETTINGS['sync'] and RANK in {-1, 0} and enabled:
cfg = dict(cfg) # convert type from DictConfig to dict
if not all_keys:
-cfg = {k: v for k, v in cfg.items() if v != DEFAULT_CONFIG_DICT.get(k, None)} # retain non-default values
+cfg = {k: v for k, v in cfg.items() if v != DEFAULT_CFG_DICT.get(k, None)} # retain non-default values
cfg['uuid'] = SETTINGS['uuid'] # add the device UUID to the configuration data
# Send a request to the HUB API to sync analytics
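The renamed `DEFAULT_CFG_DICT` is used here to strip defaults before syncing, so only user-changed settings are sent. A toy illustration of that dict comprehension (the default values below are made up for the example):

```python
DEFAULT_CFG_DICT = {"imgsz": 640, "epochs": 100, "batch": 16}  # illustrative defaults
cfg = {"imgsz": 320, "epochs": 100, "batch": 16}               # a user's merged config

# keep only keys whose values differ from the defaults
non_default = {k: v for k, v in cfg.items() if v != DEFAULT_CFG_DICT.get(k, None)}
print(non_default)  # -> {'imgsz': 320}
```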

@@ -10,7 +10,7 @@ import torch.nn as nn
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
GhostBottleneck, GhostConv, Segment)
-from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER, colorstr, yaml_load
+from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
model_info, scale_img, time_sync)
@@ -113,7 +113,7 @@ class BaseModel(nn.Module):
thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
Returns:
-bool: True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
+(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
"""
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
@@ -321,11 +321,11 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
-args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
+args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
-ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # attach args to model
+ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
ckpt.pt_path = weights # attach *.pt file path to model
if not hasattr(ckpt, 'stride'):
ckpt.stride = torch.tensor([32.])
@@ -359,11 +359,11 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
from ultralytics.yolo.utils.downloads import attempt_download
ckpt = torch.load(attempt_download(weight), map_location='cpu') # load
-args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
+args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
-model.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # attach args to model
+model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
model.pt_path = weight # attach *.pt file path to model
if not hasattr(model, 'stride'):
model.stride = torch.tensor([32.])
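The `{**defaults, **train_args}` merge keeps every default key but lets the checkpoint's training args win on conflicts. A tiny illustration with made-up values:

```python
DEFAULT_CFG_DICT = {"imgsz": 640, "conf": 0.25, "iou": 0.7}  # illustrative defaults
train_args = {"imgsz": 320}                                  # args stored in a checkpoint

args = {**DEFAULT_CFG_DICT, **train_args}  # right-hand dict wins on duplicate keys
assert args == {"imgsz": 320, "conf": 0.25, "iou": 0.7}
```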

@@ -1,156 +0,0 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import argparse
import re
import shutil
import sys
from pathlib import Path
from ultralytics import __version__, yolo
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, PREFIX, checks, print_settings, yaml_load
DIR = Path(__file__).parent
CLI_HELP_MSG = \
"""
YOLOv8 CLI Usage examples:
1. Install the ultralytics package:
pip install ultralytics
2. Train, Val, Predict and Export using 'yolo' commands:
yolo TASK MODE ARGS
Where TASK (optional) is one of [detect, segment, classify]
MODE (required) is one of [train, val, predict, export]
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
For a full list of available ARGS see https://docs.ultralytics.com/config.
Train a detection model for 10 epochs with an initial learning_rate of 0.01
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
Predict a YouTube video using a pretrained segmentation model at image size 320:
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
Validate a pretrained detection model at batch-size 1 and image size 640:
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
3. Run special commands:
yolo help
yolo checks
yolo version
yolo settings
yolo copy-config
Docs: https://docs.ultralytics.com/cli
Community: https://community.ultralytics.com
GitHub: https://github.com/ultralytics/ultralytics
"""
def cli(cfg):
"""
Run a specified task and mode with the given configuration.
Args:
cfg (DictConfig): Configuration for the task and mode.
"""
# LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}")
from ultralytics.yolo.configs import get_config
if cfg.cfg:
LOGGER.info(f"{PREFIX}Overriding default config with {cfg.cfg}")
cfg = get_config(cfg.cfg)
task, mode = cfg.task.lower(), cfg.mode.lower()
# Mapping from task to module
tasks = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}
module = tasks.get(task)
if not module:
raise SyntaxError(f"yolo task={task} is invalid. Valid tasks are: {', '.join(tasks.keys())}\n{CLI_HELP_MSG}")
# Mapping from mode to function
modes = {"train": module.train, "val": module.val, "predict": module.predict, "export": yolo.engine.exporter.export}
func = modes.get(mode)
if not func:
raise SyntaxError(f"yolo mode={mode} is invalid. Valid modes are: {', '.join(modes.keys())}\n{CLI_HELP_MSG}")
func(cfg)
def entrypoint():
"""
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
to the package. It's a combination of argparse and hydra.
This function allows for:
- passing mandatory YOLO args as a list of strings
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
- specifying the mode, either 'train', 'val', 'test', or 'predict'
- running special modes like 'checks'
- passing overrides to the package's configuration
It uses the package's default config and initializes it using the passed overrides.
Then it calls the CLI function with the composed config
"""
if len(sys.argv) == 1: # no arguments passed
LOGGER.info(CLI_HELP_MSG)
return
parser = argparse.ArgumentParser(description='YOLO parser')
parser.add_argument('args', type=str, nargs='+', help='YOLO args')
args = parser.parse_args().args
args = re.sub(r'\s*=\s*', '=', ' '.join(args)).split(' ') # remove whitespaces around = sign
tasks = 'detect', 'segment', 'classify'
modes = 'train', 'val', 'predict', 'export'
special_modes = {
'help': lambda: LOGGER.info(CLI_HELP_MSG),
'checks': checks.check_yolo,
'version': lambda: LOGGER.info(__version__),
'settings': print_settings,
'copy-config': copy_default_config}
overrides = [] # basic overrides, i.e. imgsz=320
defaults = yaml_load(DEFAULT_CONFIG)
for a in args:
if '=' in a:
overrides.append(a)
elif a in tasks:
overrides.append(f'task={a}')
elif a in modes:
overrides.append(f'mode={a}')
elif a in special_modes:
special_modes[a]()
return
elif a in defaults and defaults[a] is False:
overrides.append(f'{a}=True') # auto-True for default False args, i.e. yolo show
elif a in defaults:
raise SyntaxError(f"'{a}' is a valid YOLO argument but is missing an '=' sign to set its value, "
f"i.e. try '{a}={defaults[a]}'"
f"\n{CLI_HELP_MSG}")
else:
raise SyntaxError(
f"'{a}' is not a valid YOLO argument. For a full list of valid arguments see "
f"https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/configs/default.yaml"
f"\n{CLI_HELP_MSG}")
from hydra import compose, initialize
with initialize(version_base=None, config_path=str(DEFAULT_CONFIG.parent.relative_to(DIR)), job_name="YOLO"):
cfg = compose(config_name=DEFAULT_CONFIG.name, overrides=overrides)
cli(cfg)
# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_config():
new_file = Path.cwd() / DEFAULT_CONFIG.name.replace('.yaml', '_copy.yaml')
shutil.copy2(DEFAULT_CONFIG, new_file)
LOGGER.info(f"{PREFIX}{DEFAULT_CONFIG} copied to {new_file}\n"
f"Usage for running YOLO with this new custom config:\nyolo cfg={new_file} args...")

@@ -1,36 +1,221 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
-from pathlib import Path
-from typing import Dict, Union
-from omegaconf import DictConfig, OmegaConf
-from ultralytics.yolo.configs.hydra_patch import check_config_mismatch
-def get_config(config: Union[str, Path, DictConfig], overrides: Union[str, Dict] = None):
-"""
-Load and merge configuration data from a file or dictionary.
-Args:
-config (str) or (Path) or (DictConfig): Configuration data in the form of a file name or a DictConfig object.
-overrides (str) or(Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.
-Returns:
-OmegaConf.Namespace: Training arguments namespace.
-"""
-if overrides is None:
-overrides = {}
-if isinstance(config, (str, Path)):
-config = OmegaConf.load(config)
-elif isinstance(config, Dict):
-config = OmegaConf.create(config)
-# override
-if isinstance(overrides, str):
-overrides = OmegaConf.load(overrides)
-elif isinstance(overrides, Dict):
-overrides = OmegaConf.create(overrides)
-check_config_mismatch(dict(overrides).keys(), dict(config).keys())
-return OmegaConf.merge(config, overrides)
import argparse
import re
import shutil
import sys
from difflib import get_close_matches
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, Union
from ultralytics import __version__, yolo
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, PREFIX, checks, colorstr, print_settings, yaml_load
DIR = Path(__file__).parent
CLI_HELP_MSG = \
"""
YOLOv8 CLI Usage examples:
1. Install the ultralytics package:
pip install ultralytics
2. Train, Val, Predict and Export using 'yolo' commands:
yolo TASK MODE ARGS
Where TASK (optional) is one of [detect, segment, classify]
MODE (required) is one of [train, val, predict, export]
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
For a full list of available ARGS see https://docs.ultralytics.com/config.
Train a detection model for 10 epochs with an initial learning_rate of 0.01
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
Predict a YouTube video using a pretrained segmentation model at image size 320:
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
Validate a pretrained detection model at batch-size 1 and image size 640:
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
3. Run special commands:
yolo help
yolo checks
yolo version
yolo settings
yolo copy-config
Docs: https://docs.ultralytics.com/cli
Community: https://community.ultralytics.com
GitHub: https://github.com/ultralytics/ultralytics
"""
def cfg2dict(cfg):
"""
Convert a configuration object to a dictionary.
This function converts a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
Inputs:
cfg (str) or (Path) or (SimpleNamespace): Configuration object to be converted to a dictionary.
Returns:
cfg (dict): Configuration object in dictionary format.
"""
if isinstance(cfg, (str, Path)):
cfg = yaml_load(cfg) # load dict
elif isinstance(cfg, SimpleNamespace):
cfg = vars(cfg) # convert to dict
return cfg
def get_config(config: Union[str, Path, Dict, SimpleNamespace], overrides: Dict = None):
"""
Load and merge configuration data from a file or dictionary.
Args:
config (str) or (Path) or (Dict) or (SimpleNamespace): Configuration data.
overrides (str) or (Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.
Returns:
(SimpleNamespace): Training arguments namespace.
"""
config = cfg2dict(config)
# Merge overrides
if overrides:
overrides = cfg2dict(overrides)
check_config_mismatch(config, overrides)
config = {**config, **overrides} # merge config and overrides dicts (prefer overrides)
# Return instance
return SimpleNamespace(**config)
def check_config_mismatch(base: Dict, custom: Dict):
"""
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
Inputs:
- custom (Dict): a dictionary of custom configuration options
- base (Dict): a dictionary of base configuration options
"""
base, custom = (set(x.keys()) for x in (base, custom))
mismatched = [x for x in custom if x not in base]
for option in mismatched:
LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, base, 3, 0.6)}")
if mismatched:
sys.exit()
def entrypoint(debug=True):
"""
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
to the package.
This function allows for:
- passing mandatory YOLO args as a list of strings
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
- specifying the mode, either 'train', 'val', 'test', or 'predict'
- running special modes like 'checks'
- passing overrides to the package's configuration
It uses the package's default config and initializes it using the passed overrides.
Then it calls the CLI function with the composed config
"""
if debug:
args = ['train', 'predict', 'model=yolov8n.pt'] # for testing
else:
if len(sys.argv) == 1: # no arguments passed
LOGGER.info(CLI_HELP_MSG)
return
parser = argparse.ArgumentParser(description='YOLO parser')
parser.add_argument('args', type=str, nargs='+', help='YOLO args')
args = parser.parse_args().args
args = re.sub(r'\s*=\s*', '=', ' '.join(args)).split(' ') # remove whitespaces around = sign
tasks = 'detect', 'segment', 'classify'
modes = 'train', 'val', 'predict', 'export'
special_modes = {
'help': lambda: LOGGER.info(CLI_HELP_MSG),
'checks': checks.check_yolo,
'version': lambda: LOGGER.info(__version__),
'settings': print_settings,
'copy-config': copy_default_config}
overrides = {} # basic overrides, i.e. imgsz=320
defaults = yaml_load(DEFAULT_CFG_PATH)
for a in args:
if '=' in a:
if a.startswith('cfg='): # custom.yaml passed
custom_config = Path(a.split('=')[-1])
LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {custom_config}")
overrides = {k: v for k, v in yaml_load(custom_config).items() if k not in {'cfg'}}
else:
k, v = a.split('=')
try:
if k == 'device': # special DDP handling, i.e. device='0,1,2,3'
v = v.replace('[', '').replace(']', '') # handle device=[0,1,2,3]
v = v.replace(" ", "") # handle device=[0, 1, 2, 3]
v = v.replace('\\', '') # handle device=\'0,1,2,3\'
overrides[k] = v
else:
overrides[k] = eval(v) # convert strings to integers, floats, bools, etc.
except (NameError, SyntaxError):
overrides[k] = v
elif a in tasks:
overrides['task'] = a
elif a in modes:
overrides['mode'] = a
elif a in special_modes:
special_modes[a]()
return
elif a in defaults and defaults[a] is False:
overrides[a] = True # auto-True for default False args, i.e. 'yolo show' sets show=True
elif a in defaults:
raise SyntaxError(f"'{a}' is a valid YOLO argument but is missing an '=' sign to set its value, "
f"i.e. try '{a}={defaults[a]}'"
f"\n{CLI_HELP_MSG}")
else:
raise SyntaxError(
f"'{a}' is not a valid YOLO argument. For a full list of valid arguments see "
f"https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/configs/default.yaml"
f"\n{CLI_HELP_MSG}")
cfg = get_config(defaults, overrides) # create CFG instance
# Mapping from task to module
module = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}.get(cfg.task)
if not module:
raise SyntaxError(f"yolo task={cfg.task} is invalid. Valid tasks are: {', '.join(tasks)}\n{CLI_HELP_MSG}")
# Mapping from mode to function
func = {
"train": module.train,
"val": module.val,
"predict": module.predict,
"export": yolo.engine.exporter.export}.get(cfg.mode)
if not func:
raise SyntaxError(f"yolo mode={cfg.mode} is invalid. Valid modes are: {', '.join(modes)}\n{CLI_HELP_MSG}")
func(cfg)
# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_config():
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
shutil.copy2(DEFAULT_CFG_PATH, new_file)
LOGGER.info(f"{PREFIX}{DEFAULT_CFG_PATH} copied to {new_file}\n"
f"Usage for running YOLO with this new custom config:\nyolo cfg={new_file} args...")
if __name__ == '__main__':
entrypoint()
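With Hydra gone, `get_config()` now returns a plain `SimpleNamespace` built from `default.yaml` plus an ordinary dict of overrides. A short usage sketch of the API as defined above:

```python
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.utils import DEFAULT_CFG_PATH

# defaults come from default.yaml; overrides is an ordinary dict
cfg = get_config(DEFAULT_CFG_PATH, overrides={"imgsz": 320, "epochs": 3})
print(type(cfg).__name__)  # SimpleNamespace
print(cfg.imgsz)           # 320 -- attribute access replaces OmegaConf's DictConfig
```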

@@ -1,68 +1,68 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
# Default training settings and hyperparameters for medium-augmentation COCO training
task: "detect" # choices=['detect', 'segment', 'classify', 'init'] # init is a special case. Specify task to run.
mode: "train" # choices=['train', 'val', 'predict'] # mode to run task in.
# Train settings -------------------------------------------------------------------------------------------------------
model: null # i.e. yolov8n.pt, yolov8n.yaml. Path to model file
data: null # i.e. coco128.yaml. Path to data file
epochs: 100 # number of epochs to train for
patience: 50 # epochs to wait for no observable improvement for early stopping of training
batch: 16 # number of images per batch
imgsz: 640 # size of input images
save: True # save checkpoints
cache: False # True/ram, disk or False. Use cache for data loading
device: null # cuda device, i.e. 0 or 0,1,2,3 or cpu. Device to run on
workers: 8 # number of worker threads for data loading
project: null # project name
name: null # experiment name
exist_ok: False # whether to overwrite existing experiment
pretrained: False # whether to use a pretrained model
optimizer: 'SGD' # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
verbose: False # whether to print verbose output
seed: 0 # random seed for reproducibility
deterministic: True # whether to enable deterministic mode
single_cls: False # train multi-class data as single-class
image_weights: False # use weighted image selection for training
rect: False # support rectangular training
cos_lr: False # use cosine learning rate scheduler
close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
resume: False # resume training from last checkpoint
# Segmentation
overlap_mask: True # masks should overlap during training
mask_ratio: 4 # mask downsample ratio
# Classification
dropout: 0.0 # use dropout regularization
# Val/Test settings ----------------------------------------------------------------------------------------------------
val: True # validate/test during training
save_json: False # save results to JSON file
save_hybrid: False # save hybrid version of labels (labels + additional predictions)
conf: null # object confidence threshold for detection (default 0.25 predict, 0.001 val)
iou: 0.7 # intersection over union (IoU) threshold for NMS
max_det: 300 # maximum number of detections per image
half: False # use half precision (FP16)
dnn: False # use OpenCV DNN for ONNX inference
plots: True # show plots during training
# Prediction settings --------------------------------------------------------------------------------------------------
source: null # source directory for images or videos
show: False # show results if possible
save_txt: False # save results as .txt file
save_conf: False # save results with confidence scores
save_crop: False # save cropped images with results
hide_labels: False # hide labels
hide_conf: False # hide confidence scores
vid_stride: 1 # video frame-rate stride
line_thickness: 3 # bounding box thickness (pixels)
visualize: False # visualize results
augment: False # apply data augmentation to images
agnostic_nms: False # class-agnostic NMS
retina_masks: False # use retina masks for object detection
# Export settings ------------------------------------------------------------------------------------------------------
format: torchscript # format to export to
keras: False # use Keras
optimize: False # TorchScript: optimize for mobile
int8: False # CoreML/TF INT8 quantization
@@ -100,12 +100,8 @@ mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
-# Hydra configs --------------------------------------------------------------------------------------------------------
+# Custom config.yaml ---------------------------------------------------------------------------------------------------
cfg: null # for overriding defaults.yaml
-hydra:
-output_subdir: null # disable hydra directory creation
-run:
-dir: .
# Debug, do not modify -------------------------------------------------------------------------------------------------
v5loader: False # use legacy YOLOv5 dataloader

@@ -1,77 +0,0 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import sys
from difflib import get_close_matches
from textwrap import dedent
import hydra
from hydra.errors import ConfigCompositionException
from omegaconf import OmegaConf, open_dict # noqa
from omegaconf.errors import ConfigAttributeError, ConfigKeyError, OmegaConfBaseException # noqa
from ultralytics.yolo.utils import LOGGER, colorstr
def override_config(overrides, cfg):
override_keys = [override.key_or_group for override in overrides]
check_config_mismatch(override_keys, cfg.keys())
for override in overrides:
if override.package is not None:
raise ConfigCompositionException(f"Override {override.input_line} looks like a config group"
f" override, but config group '{override.key_or_group}' does not exist.")
key = override.key_or_group
value = override.value()
try:
if override.is_delete():
config_val = OmegaConf.select(cfg, key, throw_on_missing=False)
if config_val is None:
raise ConfigCompositionException(f"Could not delete from config. '{override.key_or_group}'"
" does not exist.")
elif value is not None and value != config_val:
raise ConfigCompositionException("Could not delete from config. The value of"
f" '{override.key_or_group}' is {config_val} and not"
f" {value}.")
last_dot = key.rfind(".")
with open_dict(cfg):
if last_dot == -1:
del cfg[key]
else:
node = OmegaConf.select(cfg, key[:last_dot])
del node[key[last_dot + 1:]]
elif override.is_add():
if OmegaConf.select(cfg, key, throw_on_missing=False) is None or isinstance(value, (dict, list)):
OmegaConf.update(cfg, key, value, merge=True, force_add=True)
else:
assert override.input_line is not None
raise ConfigCompositionException(
dedent(f"""\
Could not append to config. An item is already at '{override.key_or_group}'.
Either remove + prefix: '{override.input_line[1:]}'
Or add a second + to add or override '{override.key_or_group}': '+{override.input_line}'
"""))
elif override.is_force_add():
OmegaConf.update(cfg, key, value, merge=True, force_add=True)
else:
try:
OmegaConf.update(cfg, key, value, merge=True)
except (ConfigAttributeError, ConfigKeyError) as ex:
raise ConfigCompositionException(f"Could not override '{override.key_or_group}'."
f"\nTo append to your config use +{override.input_line}") from ex
except OmegaConfBaseException as ex:
raise ConfigCompositionException(f"Error merging override {override.input_line}").with_traceback(
sys.exc_info()[2]) from ex
def check_config_mismatch(overrides, cfg):
mismatched = [option for option in overrides if option not in cfg and 'hydra.' not in option]
for option in mismatched:
LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, cfg, 3, 0.6)}")
if mismatched:
sys.exit()
hydra._internal.config_loader_impl.ConfigLoaderImpl._apply_overrides_to_config = override_config

@@ -69,8 +69,8 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
augment=mode == "train", # augmentation
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
rect=cfg.rect if mode == "train" else True, # rectangular batches
-cache=cfg.get("cache", None),
-single_cls=cfg.get("single_cls", False),
+cache=cfg.cache or None,
+single_cls=cfg.single_cls or False,
stride=int(stride),
pad=0.0 if mode == "train" else 0.5,
prefix=colorstr(f"{mode}: "),
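Because the merged config is now a `SimpleNamespace` rather than an OmegaConf `DictConfig`, there is no `.get()` method; plain attribute access with an `or` fallback replaces it. A minimal sketch of the difference:

```python
from types import SimpleNamespace

cfg = SimpleNamespace(cache=False, single_cls=False)

# old style (DictConfig/dict): cfg.get("cache", None)
# new style (SimpleNamespace): attribute access plus an `or` fallback
cache = cfg.cache or None          # False -> None
single_cls = cfg.single_cls or False
print(cache, single_cls)           # None False
```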

@@ -29,7 +29,8 @@ from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm
from ultralytics.yolo.data.utils import check_dataset, unzip_file
-from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_kaggle
+from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
+is_kaggle)
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
@@ -493,7 +494,7 @@ class LoadImagesAndLabels(Dataset):
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
assert cache['version'] == self.cache_version # matches current version
assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
-except (FileNotFoundError, AssertionError):
+except (FileNotFoundError, AssertionError, AttributeError):
cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops
# Display cache
@@ -579,16 +580,17 @@ class LoadImagesAndLabels(Dataset):
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
self.im_hw0, self.im_hw = [None] * n, [None] * n
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
-results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
-pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
-for i, x in pbar:
-if cache_images == 'disk':
-b += self.npy_files[i].stat().st_size
-else: # 'ram'
-self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
-b += self.ims[i].nbytes
-pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
-pbar.close()
+with (Pool if n > 10000 else ThreadPool)(NUM_THREADS) as pool:
+results = pool.imap(fcn, range(n))
+pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
+for i, x in pbar:
+if cache_images == 'disk':
+b += self.npy_files[i].stat().st_size
+else: # 'ram'
+self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
+b += self.ims[i].nbytes
+pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
+pbar.close()
def check_cache_ram(self, safety_margin=0.1, prefix=''):
# Check image caching requirements vs available memory
@@ -612,11 +614,10 @@ class LoadImagesAndLabels(Dataset):
x = {} # dict
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{prefix}Scanning {path.parent / path.stem}..."
-with Pool(NUM_THREADS) as pool:
-pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
-desc=desc,
-total=len(self.im_files),
-bar_format=TQDM_BAR_FORMAT)
+total = len(self.im_files)
+with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
+results = pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix)))
+pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f
nf += nf_f
@@ -627,8 +628,8 @@ class LoadImagesAndLabels(Dataset):
if msg:
msgs.append(msg)
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
-pbar.close()
+pbar.close()
if msgs:
LOGGER.info('\n'.join(msgs))
if nf == 0:
@@ -637,12 +638,12 @@ class LoadImagesAndLabels(Dataset):
x['results'] = nf, nm, ne, nc, len(self.im_files)
x['msgs'] = msgs # warnings
x['version'] = self.cache_version # cache version
-try:
-np.save(path, x) # save cache for next time
+if is_dir_writeable(path.parent):
+np.save(str(path), x) # save cache for next time
path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
LOGGER.info(f'{prefix}New cache created: {path}')
-except Exception as e:
-LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}') # not writeable
+else:
+LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable') # not writeable
return x
def __len__(self):
@@ -1148,8 +1149,10 @@ class HUBDatasetStats():
continue
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
desc = f'{split} images'
-for _ in tqdm(ThreadPool(NUM_THREADS).imap(self._hub_ops, dataset.im_files), total=dataset.n, desc=desc):
-pass
+total = dataset.n
+with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
+for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc):
+pass
print(f'Done. All images saved to {self.im_dir}')
return self.im_dir
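These loops now pick a process pool only above 10 000 items and otherwise fall back to a thread pool, avoiding fork/pickle overhead on small datasets. A standalone sketch of the pattern (`NUM_THREADS` and the worker function are placeholders here, not the package's own):

```python
from multiprocessing.pool import Pool, ThreadPool

from tqdm import tqdm

NUM_THREADS = 8  # placeholder; the package computes its own NUM_THREADS

def verify(i):   # stand-in for verify_image_label / load_image
    return i * i

n = 5000  # < 10000, so ThreadPool is chosen and this runs as a plain script
with (Pool if n > 10000 else ThreadPool)(NUM_THREADS) as pool:
    for _ in tqdm(pool.imap(verify, range(n)), total=n):
        pass
```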

@@ -1,13 +1,13 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
from itertools import repeat
-from multiprocessing.pool import Pool
+from multiprocessing.pool import Pool, ThreadPool
from pathlib import Path
import torchvision
from tqdm import tqdm
-from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
+from ..utils import NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
from .augment import *
from .base import BaseDataset
from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
@@ -50,14 +50,12 @@ class YOLODataset(BaseDataset):
x = {"labels": []}
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
-with Pool(NUM_THREADS) as pool:
-pbar = tqdm(
-pool.imap(verify_image_label,
-zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
-desc=desc,
-total=len(self.im_files),
-bar_format=TQDM_BAR_FORMAT,
-)
+total = len(self.im_files)
+with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
+results = pool.imap(func=verify_image_label,
+iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
+repeat(self.use_keypoints)))
+pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f
nf += nf_f
@@ -73,13 +71,12 @@ class YOLODataset(BaseDataset):
segments=segments,
keypoints=keypoint,
normalized=True,
-bbox_format="xywh",
-))
+bbox_format="xywh"))
if msg:
msgs.append(msg)
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
-pbar.close()
+pbar.close()
if msgs:
LOGGER.info("\n".join(msgs))
if nf == 0:
@@ -89,13 +86,12 @@ class YOLODataset(BaseDataset):
x["msgs"] = msgs # warnings
x["version"] = self.cache_version # cache version
self.im_files = [lb["im_file"] for lb in x["labels"]]
-try:
-np.save(path, x) # save cache for next time
+if is_dir_writeable(path.parent):
+np.save(str(path), x) # save cache for next time
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
LOGGER.info(f"{self.prefix}New cache created: {path}")
-except Exception as e:
-LOGGER.warning(
-f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}") # not writeable
+else:
+LOGGER.warning(f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable") # not writeable
return x
def get_labels(self):
@@ -105,7 +101,7 @@ class YOLODataset(BaseDataset):
cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
assert cache["version"] == self.cache_version # matches current version
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
-except (FileNotFoundError, AssertionError):
+except (FileNotFoundError, AssertionError, AttributeError):
cache, exists = self.cache_labels(cache_path), False # run cache ops
# Display cache
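Cache writing now asks `is_dir_writeable()` up front instead of wrapping `np.save` in a broad try/except. The helper itself is not shown in this diff; a plausible minimal sketch of such a check (the real implementation in `ultralytics.yolo.utils` may differ):

```python
from pathlib import Path

def is_dir_writeable(dir_path) -> bool:
    """Best-effort check that new files can be created in `dir_path` (sketch only)."""
    try:
        probe = Path(dir_path) / ".write_test"
        with open(probe, "w"):
            pass
        probe.unlink()  # clean up the probe file
        return True
    except OSError:
        return False
```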

@@ -99,7 +99,7 @@ names:
# Download script/URL (optional)
download: |
-from ultralytics.yoloutils.downloads import download
+from ultralytics.yolo.utils.downloads import download
from pathlib import Path
# Download labels

@@ -60,7 +60,6 @@ from collections import defaultdict
from copy import deepcopy
from pathlib import Path
-import hydra
import numpy as np
import pandas as pd
import torch
@@ -71,7 +70,7 @@ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, Segmentati
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
from ultralytics.yolo.data.utils import check_dataset
-from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
+from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
from ultralytics.yolo.utils.files import file_size
from ultralytics.yolo.utils.ops import Profile
@@ -123,11 +122,11 @@ class Exporter:
A class for exporting a model.
Attributes:
-args (OmegaConf): Configuration for the exporter.
+args (SimpleNamespace): Configuration for the exporter.
save_dir (Path): Directory to save results.
"""
-def __init__(self, config=DEFAULT_CONFIG, overrides=None):
+def __init__(self, config=DEFAULT_CFG, overrides=None):
"""
Initializes the Exporter class.
@@ -135,8 +134,6 @@ class Exporter:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
-if overrides is None:
-overrides = {}
self.args = get_config(config, overrides)
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
callbacks.add_integration_callbacks(self)
@@ -799,8 +796,7 @@ class Exporter:
callback(self)
-@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
-def export(cfg):
+def export(cfg=DEFAULT_CFG):
cfg.model = cfg.model or "yolov8n.yaml"
cfg.format = cfg.format or "torchscript"
@@ -818,7 +814,7 @@ def export(cfg):
from ultralytics import YOLO
model = YOLO(cfg.model)
-model.export(**cfg)
+model.export(**vars(cfg))
if __name__ == "__main__":
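`cfg` is now a `SimpleNamespace`, which is not a mapping and cannot be unpacked with `**` directly; `vars()` exposes its `__dict__`, hence `model.export(**vars(cfg))`. A tiny illustration with made-up fields:

```python
from types import SimpleNamespace

cfg = SimpleNamespace(format="onnx", imgsz=640, half=False)

kwargs = vars(cfg)  # SimpleNamespace -> plain dict, suitable for ** unpacking
print(kwargs)       # {'format': 'onnx', 'imgsz': 640, 'half': False}
```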

@@ -6,7 +6,7 @@ from ultralytics import yolo # noqa
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.engine.exporter import Exporter
-from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
+from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, yaml_load
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
@@ -151,7 +151,7 @@ class YOLO:
overrides = self.overrides.copy()
overrides.update(kwargs)
overrides["mode"] = "val"
-args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
+args = get_config(config=DEFAULT_CFG_PATH, overrides=overrides)
args.data = data or args.data
args.task = self.task
@@ -169,7 +169,7 @@ class YOLO:
overrides = self.overrides.copy()
overrides.update(kwargs)
-args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
+args = get_config(config=DEFAULT_CFG_PATH, overrides=overrides)
args.task = self.task
print(args)

@ -36,7 +36,7 @@ from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.configs import get_config from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, SETTINGS, callbacks, colorstr, ops from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, SETTINGS, callbacks, colorstr, ops
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
@ -49,7 +49,7 @@ class BasePredictor:
A base class for creating predictors. A base class for creating predictors.
Attributes: Attributes:
args (OmegaConf): Configuration for the predictor. args (SimpleNamespace): Configuration for the predictor.
save_dir (Path): Directory to save results. save_dir (Path): Directory to save results.
done_setup (bool): Whether the predictor has finished setup. done_setup (bool): Whether the predictor has finished setup.
model (nn.Module): Model used for prediction. model (nn.Module): Model used for prediction.
@ -62,7 +62,7 @@ class BasePredictor:
data_path (str): Path to data. data_path (str): Path to data.
""" """
def __init__(self, config=DEFAULT_CONFIG, overrides=None): def __init__(self, config=DEFAULT_CFG_PATH, overrides=None):
""" """
Initializes the BasePredictor class. Initializes the BasePredictor class.
@ -70,8 +70,6 @@ class BasePredictor:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None. overrides (dict, optional): Configuration overrides. Defaults to None.
""" """
if overrides is None:
overrides = {}
self.args = get_config(config, overrides) self.args = get_config(config, overrides)
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f"{self.args.mode}" name = self.args.name or f"{self.args.mode}"
@ -157,7 +155,7 @@ class BasePredictor:
if stream: if stream:
return self.stream_inference(source, model, verbose) return self.stream_inference(source, model, verbose)
else: else:
return list(chain(*list(self.stream_inference(source, model, verbose)))) # merge list of Result into one return list(self.stream_inference(source, model, verbose)) # merge list of Result into one
def predict_cli(self): def predict_cli(self):
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode # Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
@ -211,7 +209,7 @@ class BasePredictor:
if self.args.save: if self.args.save:
self.save_preds(vid_cap, i, str(self.save_dir / p.name)) self.save_preds(vid_cap, i, str(self.save_dir / p.name))
yield results yield from results
# Print time (inference-only) # Print time (inference-only)
if verbose: if verbose:
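
The two predictor changes belong together: stream_inference now does `yield from results` (one Results object at a time) instead of `yield results` (a whole list per batch), so the non-streaming path can simply wrap the generator in list() rather than flattening it with itertools.chain. A toy illustration of why the two spellings produce the same flat list:

from itertools import chain

def per_batch():
    # old behaviour: one list of results per batch
    yield ["r1", "r2"]
    yield ["r3"]

def per_result():
    # new behaviour: individual results via `yield from`
    for batch in (["r1", "r2"], ["r3"]):
        yield from batch

assert list(chain(*per_batch())) == list(per_result()) == ["r1", "r2", "r3"]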

@ -15,8 +15,6 @@ import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from omegaconf import OmegaConf # noqa
from omegaconf import open_dict
from torch.cuda import amp from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import lr_scheduler from torch.optim import lr_scheduler
@ -27,7 +25,7 @@ from ultralytics import __version__
from ultralytics.nn.tasks import attempt_load_one_weight from ultralytics.nn.tasks import attempt_load_one_weight
from ultralytics.yolo.configs import get_config from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
yaml_save) yaml_save)
from ultralytics.yolo.utils.autobatch import check_train_batch_size from ultralytics.yolo.utils.autobatch import check_train_batch_size
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
@ -43,7 +41,7 @@ class BaseTrainer:
A base class for creating trainers. A base class for creating trainers.
Attributes: Attributes:
args (OmegaConf): Configuration for the trainer. args (SimpleNamespace): Configuration for the trainer.
check_resume (method): Method to check if training should be resumed from a saved checkpoint. check_resume (method): Method to check if training should be resumed from a saved checkpoint.
console (logging.Logger): Logger instance. console (logging.Logger): Logger instance.
validator (BaseValidator): Validator instance. validator (BaseValidator): Validator instance.
@ -73,7 +71,7 @@ class BaseTrainer:
csv (Path): Path to results CSV file. csv (Path): Path to results CSV file.
""" """
def __init__(self, config=DEFAULT_CONFIG, overrides=None): def __init__(self, config=DEFAULT_CFG_PATH, overrides=None):
""" """
Initializes the BaseTrainer class. Initializes the BaseTrainer class.
@ -81,8 +79,6 @@ class BaseTrainer:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None. overrides (dict, optional): Configuration overrides. Defaults to None.
""" """
if overrides is None:
overrides = {}
self.args = get_config(config, overrides) self.args = get_config(config, overrides)
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch) self.device = utils.torch_utils.select_device(self.args.device, self.args.batch)
self.check_resume() self.check_resume()
@ -95,23 +91,23 @@ class BaseTrainer:
# Dirs # Dirs
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f"{self.args.mode}" name = self.args.name or f"{self.args.mode}"
self.save_dir = Path( if hasattr(self.args, 'save_dir'):
self.args.get( self.save_dir = Path(self.args.save_dir)
"save_dir", else:
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True))) self.save_dir = Path(
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True))
self.wdir = self.save_dir / 'weights' # weights dir self.wdir = self.save_dir / 'weights' # weights dir
if RANK in {-1, 0}: if RANK in {-1, 0}:
self.wdir.mkdir(parents=True, exist_ok=True) # make dir self.wdir.mkdir(parents=True, exist_ok=True) # make dir
with open_dict(self.args): self.args.save_dir = str(self.save_dir)
self.args.save_dir = str(self.save_dir) yaml_save(self.save_dir / 'args.yaml', vars(self.args)) # save run args
yaml_save(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) # save run args
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
self.batch_size = self.args.batch self.batch_size = self.args.batch
self.epochs = self.args.epochs self.epochs = self.args.epochs
self.start_epoch = 0 self.start_epoch = 0
if RANK == -1: if RANK == -1:
print_args(dict(self.args)) print_args(vars(self.args))
# Device # Device
self.amp = self.device.type != 'cpu' self.amp = self.device.type != 'cpu'
@ -373,7 +369,7 @@ class BaseTrainer:
'ema': deepcopy(self.ema.ema).half(), 'ema': deepcopy(self.ema.ema).half(),
'updates': self.ema.updates, 'updates': self.ema.updates,
'optimizer': self.optimizer.state_dict(), 'optimizer': self.optimizer.state_dict(),
'train_args': self.args, 'train_args': vars(self.args), # save as dict
'date': datetime.now().isoformat(), 'date': datetime.now().isoformat(),
'version': __version__} 'version': __version__}
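
Because self.args is now a SimpleNamespace rather than an OmegaConf object, the trainer probes attributes with hasattr(), serializes run args with vars(), and stores train_args in the checkpoint as a plain dict. A small illustration of why vars() is the drop-in replacement (assuming yaml_save just dumps a dict):

from types import SimpleNamespace

import yaml

args = SimpleNamespace(task="detect", mode="train", epochs=1)

# adding a key needed OmegaConf's open_dict() before; a namespace just takes the attribute
args.save_dir = "runs/detect/train"

# vars() returns the underlying dict, which yaml (and torch.save) handle directly
print(yaml.safe_dump(vars(args), sort_keys=False))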

@ -5,12 +5,12 @@ from collections import defaultdict
from pathlib import Path from pathlib import Path
import torch import torch
from omegaconf import OmegaConf # noqa
from tqdm import tqdm from tqdm import tqdm
from ultralytics.nn.autobackend import AutoBackend from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks
from ultralytics.yolo.utils.checks import check_imgsz from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.ops import Profile from ultralytics.yolo.utils.ops import Profile
@ -27,7 +27,7 @@ class BaseValidator:
dataloader (DataLoader): Dataloader to use for validation. dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation. pbar (tqdm): Progress bar to update during validation.
logger (logging.Logger): Logger to use for validation. logger (logging.Logger): Logger to use for validation.
args (OmegaConf): Configuration for the validator. args (SimpleNamespace): Configuration for the validator.
model (nn.Module): Model to validate. model (nn.Module): Model to validate.
data (dict): Data dictionary. data (dict): Data dictionary.
device (torch.device): Device to use for validation. device (torch.device): Device to use for validation.
@ -47,12 +47,12 @@ class BaseValidator:
save_dir (Path): Directory to save results. save_dir (Path): Directory to save results.
pbar (tqdm.tqdm): Progress bar for displaying progress. pbar (tqdm.tqdm): Progress bar for displaying progress.
logger (logging.Logger): Logger to log messages. logger (logging.Logger): Logger to log messages.
args (OmegaConf): Configuration for the validator. args (SimpleNamespace): Configuration for the validator.
""" """
self.dataloader = dataloader self.dataloader = dataloader
self.pbar = pbar self.pbar = pbar
self.logger = logger or LOGGER self.logger = logger or LOGGER
self.args = args or OmegaConf.load(DEFAULT_CONFIG) self.args = args or get_config(DEFAULT_CFG_PATH)
self.model = None self.model = None
self.data = None self.data = None
self.device = None self.device = None

@ -8,6 +8,7 @@ import platform
import sys import sys
import tempfile import tempfile
import threading import threading
import types
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@ -22,7 +23,7 @@ import yaml
# Constants # Constants
FILE = Path(__file__).resolve() FILE = Path(__file__).resolve()
ROOT = FILE.parents[2] # YOLO ROOT = FILE.parents[2] # YOLO
DEFAULT_CONFIG = ROOT / "yolo/configs/default.yaml" DEFAULT_CFG_PATH = ROOT / "yolo/configs/default.yaml"
RANK = int(os.getenv('RANK', -1)) RANK = int(os.getenv('RANK', -1))
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
@ -73,9 +74,10 @@ os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
# Default config dictionary # Default config dictionary
with open(DEFAULT_CONFIG, errors='ignore') as f: with open(DEFAULT_CFG_PATH, errors='ignore') as f:
DEFAULT_CONFIG_DICT = yaml.safe_load(f) DEFAULT_CFG_DICT = yaml.safe_load(f)
DEFAULT_CONFIG_KEYS = DEFAULT_CONFIG_DICT.keys() DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
DEFAULT_CFG = types.SimpleNamespace(**DEFAULT_CFG_DICT)
def is_colab(): def is_colab():
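
The new module-level constants are built in three steps: read default.yaml into DEFAULT_CFG_DICT, keep its keys as DEFAULT_CFG_KEYS, and wrap the dict in a SimpleNamespace as DEFAULT_CFG so entrypoints get attribute access. Reproduced as a standalone snippet, with an inline YAML string standing in for default.yaml:

import types

import yaml

# inline stand-in for yolo/configs/default.yaml
DEFAULT_YAML = """
task: detect
mode: train
epochs: 100
imgsz: 640
"""

DEFAULT_CFG_DICT = yaml.safe_load(DEFAULT_YAML)
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
DEFAULT_CFG = types.SimpleNamespace(**DEFAULT_CFG_DICT)

print(DEFAULT_CFG.imgsz)         # attribute access for the entrypoints
print(DEFAULT_CFG_DICT["task"])  # dict access where a mapping is needed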

@ -28,7 +28,7 @@ def generate_ddp_file(trainer):
if not trainer.resume: if not trainer.resume:
shutil.rmtree(trainer.save_dir) # remove the save_dir shutil.rmtree(trainer.save_dir) # remove the save_dir
content = f'''config = {dict(trainer.args)} \nif __name__ == "__main__": content = f'''config = {vars(trainer.args)} \nif __name__ == "__main__":
from ultralytics.{import_path} import {trainer.__class__.__name__} from ultralytics.{import_path} import {trainer.__class__.__name__}
trainer = {trainer.__class__.__name__}(config=config) trainer = {trainer.__class__.__name__}(config=config)
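
dict(trainer.args) worked while args was an OmegaConf mapping, but a SimpleNamespace is neither a mapping nor iterable, so the DDP file generator switches to vars(), which reads the instance __dict__ directly:

from types import SimpleNamespace

args = SimpleNamespace(model="yolov8n.pt", epochs=1)

print(vars(args))  # {'model': 'yolov8n.pt', 'epochs': 1}

try:
    dict(args)     # a SimpleNamespace cannot be converted this way
except TypeError as err:
    print(f"dict() fails as expected: {err}")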

@ -18,7 +18,7 @@ import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
import ultralytics import ultralytics
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER
from ultralytics.yolo.utils.checks import git_describe from ultralytics.yolo.utils.checks import git_describe
from .checks import check_version from .checks import check_version
@ -288,7 +288,7 @@ def strip_optimizer(f='best.pt', s=''):
None None
""" """
x = torch.load(f, map_location=torch.device('cpu')) x = torch.load(f, map_location=torch.device('cpu'))
args = {**DEFAULT_CONFIG_DICT, **x['train_args']} # combine model args with default args, preferring model args args = {**DEFAULT_CFG_DICT, **x['train_args']} # combine model args with default args, preferring model args
if x.get('ema'): if x.get('ema'):
x['model'] = x['ema'] # replace model with ema x['model'] = x['ema'] # replace model with ema
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
@ -297,7 +297,8 @@ def strip_optimizer(f='best.pt', s=''):
x['model'].half() # to FP16 x['model'].half() # to FP16
for p in x['model'].parameters(): for p in x['model'].parameters():
p.requires_grad = False p.requires_grad = False
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # strip non-default keys x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
# x['model'].args = x['train_args']
torch.save(x, s or f) torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # filesize mb = os.path.getsize(s or f) / 1E6 # filesize
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB") LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
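
strip_optimizer keeps the same merge logic, only the names change: the checkpoint's train_args (now a plain dict) override the defaults, and anything not present in DEFAULT_CFG_KEYS is dropped before re-saving. The dict arithmetic in isolation, with stand-in defaults:

# stand-in defaults; the real values come from DEFAULT_CFG_DICT / DEFAULT_CFG_KEYS
DEFAULT_CFG_DICT = {"task": "detect", "epochs": 100, "imgsz": 640}
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()

train_args = {"epochs": 3, "save_dir": "runs/detect/train"}  # as stored in a checkpoint

args = {**DEFAULT_CFG_DICT, **train_args}                            # checkpoint values win
stripped = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # drop non-default keys
print(stripped)  # {'task': 'detect', 'epochs': 3, 'imgsz': 640}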

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
from ultralytics.yolo.configs import hydra_patch # noqa (patch hydra CLI)
from ultralytics.yolo.v8 import classify, detect, segment from ultralytics.yolo.v8 import classify, detect, segment
__all__ = ["classify", "segment", "detect"] __all__ = ["classify", "segment", "detect"]

@ -1,11 +1,10 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
import hydra
import torch import torch
from ultralytics.yolo.engine.predictor import BasePredictor from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.results import Results from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory
from ultralytics.yolo.utils.plotting import Annotator from ultralytics.yolo.utils.plotting import Annotator
@ -64,8 +63,7 @@ class ClassificationPredictor(BasePredictor):
return log_string return log_string
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def predict(cfg=DEFAULT_CFG):
def predict(cfg):
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18" cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \ cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
else "https://ultralytics.com/images/bus.jpg" else "https://ultralytics.com/images/bus.jpg"

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
import hydra
import torch import torch
import torchvision import torchvision
@ -8,13 +7,13 @@ from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
from ultralytics.yolo import v8 from ultralytics.yolo import v8
from ultralytics.yolo.data import build_classification_dataloader from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.trainer import BaseTrainer from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import DEFAULT_CONFIG from ultralytics.yolo.utils import DEFAULT_CFG
from ultralytics.yolo.utils.torch_utils import strip_optimizer from ultralytics.yolo.utils.torch_utils import strip_optimizer
class ClassificationTrainer(BaseTrainer): class ClassificationTrainer(BaseTrainer):
def __init__(self, config=DEFAULT_CONFIG, overrides=None): def __init__(self, config=DEFAULT_CFG, overrides=None):
if overrides is None: if overrides is None:
overrides = {} overrides = {}
overrides["task"] = "classify" overrides["task"] = "classify"
@ -136,8 +135,7 @@ class ClassificationTrainer(BaseTrainer):
# self.run_callbacks('on_fit_epoch_end') # self.run_callbacks('on_fit_epoch_end')
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def train(cfg=DEFAULT_CFG):
def train(cfg):
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18" cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
cfg.data = cfg.data or "mnist160" # or yolo.ClassificationDataset("mnist") cfg.data = cfg.data or "mnist160" # or yolo.ClassificationDataset("mnist")
@ -152,7 +150,7 @@ def train(cfg):
# trainer.train() # trainer.train()
from ultralytics import YOLO from ultralytics import YOLO
model = YOLO(cfg.model) model = YOLO(cfg.model)
model.train(**cfg) model.train(**vars(cfg))
if __name__ == "__main__": if __name__ == "__main__":
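
With Hydra gone, train(), val(), predict() and export() are ordinary functions: call them with no argument to use DEFAULT_CFG, or copy the namespace and adjust fields before passing it in. A hedged usage sketch for the classification trainer; the module path is assumed from the file layout in this diff and applies to this version only:

from types import SimpleNamespace

from ultralytics.yolo.utils import DEFAULT_CFG
from ultralytics.yolo.v8.classify.train import train

cfg = SimpleNamespace(**vars(DEFAULT_CFG))  # copy so the shared default namespace isn't mutated
cfg.model = "yolov8n-cls.pt"
cfg.data = "mnist160"
cfg.epochs = 1

train(cfg)  # equivalent to YOLO(cfg.model).train(**vars(cfg))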

@ -1,10 +1,8 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
import hydra
from ultralytics.yolo.data import build_classification_dataloader from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.validator import BaseValidator from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import DEFAULT_CONFIG from ultralytics.yolo.utils import DEFAULT_CFG
from ultralytics.yolo.utils.metrics import ClassifyMetrics from ultralytics.yolo.utils.metrics import ClassifyMetrics
@ -46,8 +44,7 @@ class ClassificationValidator(BaseValidator):
self.logger.info(pf % ("all", self.metrics.top1, self.metrics.top5)) self.logger.info(pf % ("all", self.metrics.top1, self.metrics.top5))
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def val(cfg=DEFAULT_CFG):
def val(cfg):
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18" cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
cfg.data = cfg.data or "imagenette160" cfg.data = cfg.data or "imagenette160"
validator = ClassificationValidator(args=cfg) validator = ClassificationValidator(args=cfg)

@ -1,11 +1,10 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
import hydra
import torch import torch
from ultralytics.yolo.engine.predictor import BasePredictor from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.results import Results from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory, ops from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory, ops
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
@ -81,8 +80,7 @@ class DetectionPredictor(BasePredictor):
return log_string return log_string
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def predict(cfg=DEFAULT_CFG):
def predict(cfg):
cfg.model = cfg.model or "yolov8n.pt" cfg.model = cfg.model or "yolov8n.pt"
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \ cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
else "https://ultralytics.com/images/bus.jpg" else "https://ultralytics.com/images/bus.jpg"

@ -2,7 +2,6 @@
from copy import copy from copy import copy
import hydra
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -11,7 +10,7 @@ from ultralytics.yolo import v8
from ultralytics.yolo.data import build_dataloader from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
from ultralytics.yolo.engine.trainer import BaseTrainer from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import DEFAULT_CONFIG, colorstr from ultralytics.yolo.utils import DEFAULT_CFG, colorstr
from ultralytics.yolo.utils.loss import BboxLoss from ultralytics.yolo.utils.loss import BboxLoss
from ultralytics.yolo.utils.ops import xywh2xyxy from ultralytics.yolo.utils.ops import xywh2xyxy
from ultralytics.yolo.utils.plotting import plot_images, plot_results from ultralytics.yolo.utils.plotting import plot_images, plot_results
@ -30,7 +29,7 @@ class DetectionTrainer(BaseTrainer):
imgsz=self.args.imgsz, imgsz=self.args.imgsz,
batch_size=batch_size, batch_size=batch_size,
stride=gs, stride=gs,
hyp=dict(self.args), hyp=vars(self.args),
augment=mode == "train", augment=mode == "train",
cache=self.args.cache, cache=self.args.cache,
pad=0 if mode == "train" else 0.5, pad=0 if mode == "train" else 0.5,
@ -195,8 +194,7 @@ class Loss:
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def train(cfg=DEFAULT_CFG):
def train(cfg):
cfg.model = cfg.model or "yolov8n.pt" cfg.model = cfg.model or "yolov8n.pt"
cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist") cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist")
cfg.device = cfg.device if cfg.device is not None else '' cfg.device = cfg.device if cfg.device is not None else ''
@ -204,7 +202,7 @@ def train(cfg):
# trainer.train() # trainer.train()
from ultralytics import YOLO from ultralytics import YOLO
model = YOLO(cfg.model) model = YOLO(cfg.model)
model.train(**cfg) model.train(**vars(cfg))
if __name__ == "__main__": if __name__ == "__main__":

@ -3,14 +3,13 @@
import os import os
from pathlib import Path from pathlib import Path
import hydra
import numpy as np import numpy as np
import torch import torch
from ultralytics.yolo.data import build_dataloader from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
from ultralytics.yolo.engine.validator import BaseValidator from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import DEFAULT_CONFIG, colorstr, ops, yaml_load from ultralytics.yolo.utils import DEFAULT_CFG, colorstr, ops, yaml_load
from ultralytics.yolo.utils.checks import check_file, check_requirements from ultralytics.yolo.utils.checks import check_file, check_requirements
from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@ -168,7 +167,7 @@ class DetectionValidator(BaseValidator):
imgsz=self.args.imgsz, imgsz=self.args.imgsz,
batch_size=batch_size, batch_size=batch_size,
stride=gs, stride=gs,
hyp=dict(self.args), hyp=vars(self.args),
cache=False, cache=False,
pad=0.5, pad=0.5,
rect=True, rect=True,
@ -232,8 +231,7 @@ class DetectionValidator(BaseValidator):
return stats return stats
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def val(cfg=DEFAULT_CFG):
def val(cfg):
cfg.model = cfg.model or "yolov8n.pt" cfg.model = cfg.model or "yolov8n.pt"
cfg.data = cfg.data or "coco128.yaml" cfg.data = cfg.data or "coco128.yaml"
validator = DetectionValidator(args=cfg) validator = DetectionValidator(args=cfg)

@ -1,10 +1,9 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
import hydra
import torch import torch
from ultralytics.yolo.engine.results import Results from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory, ops from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory, ops
from ultralytics.yolo.utils.plotting import colors, save_one_box from ultralytics.yolo.utils.plotting import colors, save_one_box
from ultralytics.yolo.v8.detect.predict import DetectionPredictor from ultralytics.yolo.v8.detect.predict import DetectionPredictor
@ -98,8 +97,7 @@ class SegmentationPredictor(DetectionPredictor):
return log_string return log_string
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def predict(cfg=DEFAULT_CFG):
def predict(cfg):
cfg.model = cfg.model or "yolov8n-seg.pt" cfg.model = cfg.model or "yolov8n-seg.pt"
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \ cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
else "https://ultralytics.com/images/bus.jpg" else "https://ultralytics.com/images/bus.jpg"

@ -2,13 +2,12 @@
from copy import copy from copy import copy
import hydra
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from ultralytics.nn.tasks import SegmentationModel from ultralytics.nn.tasks import SegmentationModel
from ultralytics.yolo import v8 from ultralytics.yolo import v8
from ultralytics.yolo.utils import DEFAULT_CONFIG from ultralytics.yolo.utils import DEFAULT_CFG
from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh
from ultralytics.yolo.utils.plotting import plot_images, plot_results from ultralytics.yolo.utils.plotting import plot_images, plot_results
from ultralytics.yolo.utils.tal import make_anchors from ultralytics.yolo.utils.tal import make_anchors
@ -19,7 +18,7 @@ from ultralytics.yolo.v8.detect.train import Loss
# BaseTrainer python usage # BaseTrainer python usage
class SegmentationTrainer(v8.detect.DetectionTrainer): class SegmentationTrainer(v8.detect.DetectionTrainer):
def __init__(self, config=DEFAULT_CONFIG, overrides=None): def __init__(self, config=DEFAULT_CFG, overrides=None):
if overrides is None: if overrides is None:
overrides = {} overrides = {}
overrides["task"] = "segment" overrides["task"] = "segment"
@ -141,8 +140,7 @@ class SegLoss(Loss):
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean() return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def train(cfg=DEFAULT_CFG):
def train(cfg):
cfg.model = cfg.model or "yolov8n-seg.pt" cfg.model = cfg.model or "yolov8n-seg.pt"
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist") cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
cfg.device = cfg.device if cfg.device is not None else '' cfg.device = cfg.device if cfg.device is not None else ''
@ -150,7 +148,7 @@ def train(cfg):
# trainer.train() # trainer.train()
from ultralytics import YOLO from ultralytics import YOLO
model = YOLO(cfg.model) model = YOLO(cfg.model)
model.train(**cfg) model.train(**vars(cfg))
if __name__ == "__main__": if __name__ == "__main__":

@ -4,12 +4,11 @@ import os
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from pathlib import Path from pathlib import Path
import hydra
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from ultralytics.yolo.utils import DEFAULT_CONFIG, NUM_THREADS, ops from ultralytics.yolo.utils import DEFAULT_CFG, NUM_THREADS, ops
from ultralytics.yolo.utils.checks import check_requirements from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@ -243,8 +242,7 @@ class SegmentationValidator(DetectionValidator):
return stats return stats
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def val(cfg=DEFAULT_CFG):
def val(cfg):
cfg.data = cfg.data or "coco128-seg.yaml" cfg.data = cfg.data or "coco128-seg.yaml"
validator = SegmentationValidator(args=cfg) validator = SegmentationValidator(args=cfg)
validator(model=cfg.model) validator(model=cfg.model)
