ultralytics 8.0.12
- Hydra removal (#506)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Pronoy Mandal <lukex9442@gmail.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
__version__ = "8.0.11"
|
||||
__version__ = "8.0.12"
|
||||
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.utils import ops
|
||||
|
@ -7,7 +7,7 @@ import time
|
||||
|
||||
import requests
|
||||
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, LOGGER, RANK, SETTINGS, TryExcept, colorstr, emojis
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, TryExcept, colorstr, emojis
|
||||
|
||||
PREFIX = colorstr('Ultralytics: ')
|
||||
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
||||
@ -143,7 +143,7 @@ def sync_analytics(cfg, all_keys=False, enabled=False):
|
||||
if SETTINGS['sync'] and RANK in {-1, 0} and enabled:
|
||||
cfg = dict(cfg) # convert type from DictConfig to dict
|
||||
if not all_keys:
|
||||
cfg = {k: v for k, v in cfg.items() if v != DEFAULT_CONFIG_DICT.get(k, None)} # retain non-default values
|
||||
cfg = {k: v for k, v in cfg.items() if v != DEFAULT_CFG_DICT.get(k, None)} # retain non-default values
|
||||
cfg['uuid'] = SETTINGS['uuid'] # add the device UUID to the configuration data
|
||||
|
||||
# Send a request to the HUB API to sync analytics
|
||||
|
@ -10,7 +10,7 @@ import torch.nn as nn
|
||||
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
||||
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
||||
GhostBottleneck, GhostConv, Segment)
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER, colorstr, yaml_load
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
|
||||
model_info, scale_img, time_sync)
|
||||
@ -113,7 +113,7 @@ class BaseModel(nn.Module):
|
||||
thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
|
||||
|
||||
Returns:
|
||||
bool: True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
|
||||
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
|
||||
"""
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
||||
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
||||
@ -321,11 +321,11 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
model = Ensemble()
|
||||
for w in weights if isinstance(weights, list) else [weights]:
|
||||
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
||||
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
|
||||
# Model compatibility updates
|
||||
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # attach args to model
|
||||
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
||||
ckpt.pt_path = weights # attach *.pt file path to model
|
||||
if not hasattr(ckpt, 'stride'):
|
||||
ckpt.stride = torch.tensor([32.])
|
||||
@ -359,11 +359,11 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
||||
from ultralytics.yolo.utils.downloads import attempt_download
|
||||
|
||||
ckpt = torch.load(attempt_download(weight), map_location='cpu') # load
|
||||
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
|
||||
# Model compatibility updates
|
||||
model.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # attach args to model
|
||||
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
||||
model.pt_path = weight # attach *.pt file path to model
|
||||
if not hasattr(model, 'stride'):
|
||||
model.stride = torch.tensor([32.])
|
||||
|
@ -1,156 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics import __version__, yolo
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, PREFIX, checks, print_settings, yaml_load
|
||||
|
||||
DIR = Path(__file__).parent
|
||||
|
||||
CLI_HELP_MSG = \
|
||||
"""
|
||||
YOLOv8 CLI Usage examples:
|
||||
|
||||
1. Install the ultralytics package:
|
||||
|
||||
pip install ultralytics
|
||||
|
||||
2. Train, Val, Predict and Export using 'yolo' commands:
|
||||
|
||||
yolo TASK MODE ARGS
|
||||
|
||||
Where TASK (optional) is one of [detect, segment, classify]
|
||||
MODE (required) is one of [train, val, predict, export]
|
||||
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||||
For a full list of available ARGS see https://docs.ultralytics.com/config.
|
||||
|
||||
Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
||||
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||
|
||||
Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
|
||||
|
||||
Validate a pretrained detection model at batch-size 1 and image size 640:
|
||||
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||
|
||||
Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||||
|
||||
3. Run special commands:
|
||||
|
||||
yolo help
|
||||
yolo checks
|
||||
yolo version
|
||||
yolo settings
|
||||
yolo copy-config
|
||||
|
||||
Docs: https://docs.ultralytics.com/cli
|
||||
Community: https://community.ultralytics.com
|
||||
GitHub: https://github.com/ultralytics/ultralytics
|
||||
"""
|
||||
|
||||
|
||||
def cli(cfg):
|
||||
"""
|
||||
Run a specified task and mode with the given configuration.
|
||||
|
||||
Args:
|
||||
cfg (DictConfig): Configuration for the task and mode.
|
||||
"""
|
||||
# LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}")
|
||||
from ultralytics.yolo.configs import get_config
|
||||
|
||||
if cfg.cfg:
|
||||
LOGGER.info(f"{PREFIX}Overriding default config with {cfg.cfg}")
|
||||
cfg = get_config(cfg.cfg)
|
||||
task, mode = cfg.task.lower(), cfg.mode.lower()
|
||||
|
||||
# Mapping from task to module
|
||||
tasks = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}
|
||||
module = tasks.get(task)
|
||||
if not module:
|
||||
raise SyntaxError(f"yolo task={task} is invalid. Valid tasks are: {', '.join(tasks.keys())}\n{CLI_HELP_MSG}")
|
||||
|
||||
# Mapping from mode to function
|
||||
modes = {"train": module.train, "val": module.val, "predict": module.predict, "export": yolo.engine.exporter.export}
|
||||
func = modes.get(mode)
|
||||
if not func:
|
||||
raise SyntaxError(f"yolo mode={mode} is invalid. Valid modes are: {', '.join(modes.keys())}\n{CLI_HELP_MSG}")
|
||||
|
||||
func(cfg)
|
||||
|
||||
|
||||
def entrypoint():
|
||||
"""
|
||||
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
||||
to the package. It's a combination of argparse and hydra.
|
||||
|
||||
This function allows for:
|
||||
- passing mandatory YOLO args as a list of strings
|
||||
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
|
||||
- specifying the mode, either 'train', 'val', 'test', or 'predict'
|
||||
- running special modes like 'checks'
|
||||
- passing overrides to the package's configuration
|
||||
|
||||
It uses the package's default config and initializes it using the passed overrides.
|
||||
Then it calls the CLI function with the composed config
|
||||
"""
|
||||
if len(sys.argv) == 1: # no arguments passed
|
||||
LOGGER.info(CLI_HELP_MSG)
|
||||
return
|
||||
|
||||
parser = argparse.ArgumentParser(description='YOLO parser')
|
||||
parser.add_argument('args', type=str, nargs='+', help='YOLO args')
|
||||
args = parser.parse_args().args
|
||||
args = re.sub(r'\s*=\s*', '=', ' '.join(args)).split(' ') # remove whitespaces around = sign
|
||||
|
||||
tasks = 'detect', 'segment', 'classify'
|
||||
modes = 'train', 'val', 'predict', 'export'
|
||||
special_modes = {
|
||||
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
||||
'checks': checks.check_yolo,
|
||||
'version': lambda: LOGGER.info(__version__),
|
||||
'settings': print_settings,
|
||||
'copy-config': copy_default_config}
|
||||
|
||||
overrides = [] # basic overrides, i.e. imgsz=320
|
||||
defaults = yaml_load(DEFAULT_CONFIG)
|
||||
for a in args:
|
||||
if '=' in a:
|
||||
overrides.append(a)
|
||||
elif a in tasks:
|
||||
overrides.append(f'task={a}')
|
||||
elif a in modes:
|
||||
overrides.append(f'mode={a}')
|
||||
elif a in special_modes:
|
||||
special_modes[a]()
|
||||
return
|
||||
elif a in defaults and defaults[a] is False:
|
||||
overrides.append(f'{a}=True') # auto-True for default False args, i.e. yolo show
|
||||
elif a in defaults:
|
||||
raise SyntaxError(f"'{a}' is a valid YOLO argument but is missing an '=' sign to set its value, "
|
||||
f"i.e. try '{a}={defaults[a]}'"
|
||||
f"\n{CLI_HELP_MSG}")
|
||||
else:
|
||||
raise SyntaxError(
|
||||
f"'{a}' is not a valid YOLO argument. For a full list of valid arguments see "
|
||||
f"https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/configs/default.yaml"
|
||||
f"\n{CLI_HELP_MSG}")
|
||||
|
||||
from hydra import compose, initialize
|
||||
|
||||
with initialize(version_base=None, config_path=str(DEFAULT_CONFIG.parent.relative_to(DIR)), job_name="YOLO"):
|
||||
cfg = compose(config_name=DEFAULT_CONFIG.name, overrides=overrides)
|
||||
cli(cfg)
|
||||
|
||||
|
||||
# Special modes --------------------------------------------------------------------------------------------------------
|
||||
def copy_default_config():
|
||||
new_file = Path.cwd() / DEFAULT_CONFIG.name.replace('.yaml', '_copy.yaml')
|
||||
shutil.copy2(DEFAULT_CONFIG, new_file)
|
||||
LOGGER.info(f"{PREFIX}{DEFAULT_CONFIG} copied to {new_file}\n"
|
||||
f"Usage for running YOLO with this new custom config:\nyolo cfg={new_file} args...")
|
@ -1,36 +1,221 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from difflib import get_close_matches
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, Union
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from ultralytics import __version__, yolo
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, PREFIX, checks, colorstr, print_settings, yaml_load
|
||||
|
||||
from ultralytics.yolo.configs.hydra_patch import check_config_mismatch
|
||||
DIR = Path(__file__).parent
|
||||
|
||||
CLI_HELP_MSG = \
|
||||
"""
|
||||
YOLOv8 CLI Usage examples:
|
||||
|
||||
1. Install the ultralytics package:
|
||||
|
||||
pip install ultralytics
|
||||
|
||||
2. Train, Val, Predict and Export using 'yolo' commands:
|
||||
|
||||
yolo TASK MODE ARGS
|
||||
|
||||
Where TASK (optional) is one of [detect, segment, classify]
|
||||
MODE (required) is one of [train, val, predict, export]
|
||||
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||||
For a full list of available ARGS see https://docs.ultralytics.com/config.
|
||||
|
||||
Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
||||
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||
|
||||
Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
|
||||
|
||||
Validate a pretrained detection model at batch-size 1 and image size 640:
|
||||
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||
|
||||
Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||||
|
||||
3. Run special commands:
|
||||
|
||||
yolo help
|
||||
yolo checks
|
||||
yolo version
|
||||
yolo settings
|
||||
yolo copy-config
|
||||
|
||||
Docs: https://docs.ultralytics.com/cli
|
||||
Community: https://community.ultralytics.com
|
||||
GitHub: https://github.com/ultralytics/ultralytics
|
||||
"""
|
||||
|
||||
|
||||
def get_config(config: Union[str, Path, DictConfig], overrides: Union[str, Dict] = None):
|
||||
def cfg2dict(cfg):
|
||||
"""
|
||||
Convert a configuration object to a dictionary.
|
||||
|
||||
This function converts a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
|
||||
|
||||
Inputs:
|
||||
cfg (str) or (Path) or (SimpleNamespace): Configuration object to be converted to a dictionary.
|
||||
|
||||
Returns:
|
||||
cfg (dict): Configuration object in dictionary format.
|
||||
"""
|
||||
if isinstance(cfg, (str, Path)):
|
||||
cfg = yaml_load(cfg) # load dict
|
||||
elif isinstance(cfg, SimpleNamespace):
|
||||
cfg = vars(cfg) # convert to dict
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config(config: Union[str, Path, Dict, SimpleNamespace], overrides: Dict = None):
|
||||
"""
|
||||
Load and merge configuration data from a file or dictionary.
|
||||
|
||||
Args:
|
||||
config (str) or (Path) or (DictConfig): Configuration data in the form of a file name or a DictConfig object.
|
||||
overrides (str) or(Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.
|
||||
config (str) or (Path) or (Dict) or (SimpleNamespace): Configuration data.
|
||||
overrides (str) or (Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.
|
||||
|
||||
Returns:
|
||||
OmegaConf.Namespace: Training arguments namespace.
|
||||
(SimpleNamespace): Training arguments namespace.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
if isinstance(config, (str, Path)):
|
||||
config = OmegaConf.load(config)
|
||||
elif isinstance(config, Dict):
|
||||
config = OmegaConf.create(config)
|
||||
# override
|
||||
if isinstance(overrides, str):
|
||||
overrides = OmegaConf.load(overrides)
|
||||
elif isinstance(overrides, Dict):
|
||||
overrides = OmegaConf.create(overrides)
|
||||
config = cfg2dict(config)
|
||||
|
||||
check_config_mismatch(dict(overrides).keys(), dict(config).keys())
|
||||
# Merge overrides
|
||||
if overrides:
|
||||
overrides = cfg2dict(overrides)
|
||||
check_config_mismatch(config, overrides)
|
||||
config = {**config, **overrides} # merge config and overrides dicts (prefer overrides)
|
||||
|
||||
return OmegaConf.merge(config, overrides)
|
||||
# Return instance
|
||||
return SimpleNamespace(**config)
|
||||
|
||||
|
||||
def check_config_mismatch(base: Dict, custom: Dict):
|
||||
"""
|
||||
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
|
||||
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
|
||||
|
||||
Inputs:
|
||||
- custom (Dict): a dictionary of custom configuration options
|
||||
- base (Dict): a dictionary of base configuration options
|
||||
"""
|
||||
base, custom = (set(x.keys()) for x in (base, custom))
|
||||
mismatched = [x for x in custom if x not in base]
|
||||
for option in mismatched:
|
||||
LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, base, 3, 0.6)}")
|
||||
if mismatched:
|
||||
sys.exit()
|
||||
|
||||
|
||||
def entrypoint(debug=True):
|
||||
"""
|
||||
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
||||
to the package.
|
||||
|
||||
This function allows for:
|
||||
- passing mandatory YOLO args as a list of strings
|
||||
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
|
||||
- specifying the mode, either 'train', 'val', 'test', or 'predict'
|
||||
- running special modes like 'checks'
|
||||
- passing overrides to the package's configuration
|
||||
|
||||
It uses the package's default config and initializes it using the passed overrides.
|
||||
Then it calls the CLI function with the composed config
|
||||
"""
|
||||
if debug:
|
||||
args = ['train', 'predict', 'model=yolov8n.pt'] # for testing
|
||||
else:
|
||||
if len(sys.argv) == 1: # no arguments passed
|
||||
LOGGER.info(CLI_HELP_MSG)
|
||||
return
|
||||
|
||||
parser = argparse.ArgumentParser(description='YOLO parser')
|
||||
parser.add_argument('args', type=str, nargs='+', help='YOLO args')
|
||||
args = parser.parse_args().args
|
||||
args = re.sub(r'\s*=\s*', '=', ' '.join(args)).split(' ') # remove whitespaces around = sign
|
||||
|
||||
tasks = 'detect', 'segment', 'classify'
|
||||
modes = 'train', 'val', 'predict', 'export'
|
||||
special_modes = {
|
||||
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
||||
'checks': checks.check_yolo,
|
||||
'version': lambda: LOGGER.info(__version__),
|
||||
'settings': print_settings,
|
||||
'copy-config': copy_default_config}
|
||||
|
||||
overrides = {} # basic overrides, i.e. imgsz=320
|
||||
defaults = yaml_load(DEFAULT_CFG_PATH)
|
||||
for a in args:
|
||||
if '=' in a:
|
||||
if a.startswith('cfg='): # custom.yaml passed
|
||||
custom_config = Path(a.split('=')[-1])
|
||||
LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {custom_config}")
|
||||
overrides = {k: v for k, v in yaml_load(custom_config).items() if k not in {'cfg'}}
|
||||
else:
|
||||
k, v = a.split('=')
|
||||
try:
|
||||
if k == 'device': # special DDP handling, i.e. device='0,1,2,3'
|
||||
v = v.replace('[', '').replace(']', '') # handle device=[0,1,2,3]
|
||||
v = v.replace(" ", "").replace('') # handle device=[0, 1, 2, 3]
|
||||
v = v.replace('\\', '') # handle device=\'0,1,2,3\'
|
||||
overrides[k] = v
|
||||
else:
|
||||
overrides[k] = eval(v) # convert strings to integers, floats, bools, etc.
|
||||
except (NameError, SyntaxError):
|
||||
overrides[k] = v
|
||||
elif a in tasks:
|
||||
overrides['task'] = a
|
||||
elif a in modes:
|
||||
overrides['mode'] = a
|
||||
elif a in special_modes:
|
||||
special_modes[a]()
|
||||
return
|
||||
elif a in defaults and defaults[a] is False:
|
||||
overrides[a] = True # auto-True for default False args, i.e. 'yolo show' sets show=True
|
||||
elif a in defaults:
|
||||
raise SyntaxError(f"'{a}' is a valid YOLO argument but is missing an '=' sign to set its value, "
|
||||
f"i.e. try '{a}={defaults[a]}'"
|
||||
f"\n{CLI_HELP_MSG}")
|
||||
else:
|
||||
raise SyntaxError(
|
||||
f"'{a}' is not a valid YOLO argument. For a full list of valid arguments see "
|
||||
f"https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/configs/default.yaml"
|
||||
f"\n{CLI_HELP_MSG}")
|
||||
|
||||
cfg = get_config(defaults, overrides) # create CFG instance
|
||||
|
||||
# Mapping from task to module
|
||||
module = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}.get(cfg.task)
|
||||
if not module:
|
||||
raise SyntaxError(f"yolo task={cfg.task} is invalid. Valid tasks are: {', '.join(tasks)}\n{CLI_HELP_MSG}")
|
||||
|
||||
# Mapping from mode to function
|
||||
func = {
|
||||
"train": module.train,
|
||||
"val": module.val,
|
||||
"predict": module.predict,
|
||||
"export": yolo.engine.exporter.export}.get(cfg.mode)
|
||||
if not func:
|
||||
raise SyntaxError(f"yolo mode={cfg.mode} is invalid. Valid modes are: {', '.join(modes)}\n{CLI_HELP_MSG}")
|
||||
|
||||
func(cfg)
|
||||
|
||||
|
||||
# Special modes --------------------------------------------------------------------------------------------------------
|
||||
def copy_default_config():
|
||||
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
||||
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
||||
LOGGER.info(f"{PREFIX}{DEFAULT_CFG_PATH} copied to {new_file}\n"
|
||||
f"Usage for running YOLO with this new custom config:\nyolo cfg={new_file} args...")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
entrypoint()
|
||||
|
@ -1,68 +1,68 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
# Default training settings and hyperparameters for medium-augmentation COCO training
|
||||
|
||||
task: "detect" # choices=['detect', 'segment', 'classify', 'init'] # init is a special case. Specify task to run.
|
||||
mode: "train" # choices=['train', 'val', 'predict'] # mode to run task in.
|
||||
task: "detect" # choices=['detect', 'segment', 'classify', 'init'] # init is a special case. Specify task to run.
|
||||
mode: "train" # choices=['train', 'val', 'predict'] # mode to run task in.
|
||||
|
||||
# Train settings -------------------------------------------------------------------------------------------------------
|
||||
model: null # i.e. yolov8n.pt, yolov8n.yaml. Path to model file
|
||||
data: null # i.e. coco128.yaml. Path to data file
|
||||
epochs: 100 # number of epochs to train for
|
||||
model: null # i.e. yolov8n.pt, yolov8n.yaml. Path to model file
|
||||
data: null # i.e. coco128.yaml. Path to data file
|
||||
epochs: 100 # number of epochs to train for
|
||||
patience: 50 # epochs to wait for no observable improvement for early stopping of training
|
||||
batch: 16 # number of images per batch
|
||||
imgsz: 640 # size of input images
|
||||
save: True # save checkpoints
|
||||
cache: False # True/ram, disk or False. Use cache for data loading
|
||||
device: null # cuda device, i.e. 0 or 0,1,2,3 or cpu. Device to run on
|
||||
workers: 8 # number of worker threads for data loading
|
||||
project: null # project name
|
||||
name: null # experiment name
|
||||
exist_ok: False # whether to overwrite existing experiment
|
||||
pretrained: False # whether to use a pretrained model
|
||||
optimizer: 'SGD' # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
|
||||
verbose: False # whether to print verbose output
|
||||
seed: 0 # random seed for reproducibility
|
||||
deterministic: True # whether to enable deterministic mode
|
||||
single_cls: False # train multi-class data as single-class
|
||||
image_weights: False # use weighted image selection for training
|
||||
rect: False # support rectangular training
|
||||
cos_lr: False # use cosine learning rate scheduler
|
||||
close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
|
||||
resume: False # resume training from last checkpoint
|
||||
batch: 16 # number of images per batch
|
||||
imgsz: 640 # size of input images
|
||||
save: True # save checkpoints
|
||||
cache: False # True/ram, disk or False. Use cache for data loading
|
||||
device: null # cuda device, i.e. 0 or 0,1,2,3 or cpu. Device to run on
|
||||
workers: 8 # number of worker threads for data loading
|
||||
project: null # project name
|
||||
name: null # experiment name
|
||||
exist_ok: False # whether to overwrite existing experiment
|
||||
pretrained: False # whether to use a pretrained model
|
||||
optimizer: 'SGD' # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
|
||||
verbose: False # whether to print verbose output
|
||||
seed: 0 # random seed for reproducibility
|
||||
deterministic: True # whether to enable deterministic mode
|
||||
single_cls: False # train multi-class data as single-class
|
||||
image_weights: False # use weighted image selection for training
|
||||
rect: False # support rectangular training
|
||||
cos_lr: False # use cosine learning rate scheduler
|
||||
close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
|
||||
resume: False # resume training from last checkpoint
|
||||
# Segmentation
|
||||
overlap_mask: True # masks should overlap during training
|
||||
mask_ratio: 4 # mask downsample ratio
|
||||
overlap_mask: True # masks should overlap during training
|
||||
mask_ratio: 4 # mask downsample ratio
|
||||
# Classification
|
||||
dropout: 0.0 # use dropout regularization
|
||||
|
||||
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
||||
val: True # validate/test during training
|
||||
save_json: False # save results to JSON file
|
||||
save_hybrid: False # save hybrid version of labels (labels + additional predictions)
|
||||
conf: null # object confidence threshold for detection (default 0.25 predict, 0.001 val)
|
||||
iou: 0.7 # intersection over union (IoU) threshold for NMS
|
||||
max_det: 300 # maximum number of detections per image
|
||||
half: False # use half precision (FP16)
|
||||
dnn: False # use OpenCV DNN for ONNX inference
|
||||
plots: True # show plots during training
|
||||
val: True # validate/test during training
|
||||
save_json: False # save results to JSON file
|
||||
save_hybrid: False # save hybrid version of labels (labels + additional predictions)
|
||||
conf: null # object confidence threshold for detection (default 0.25 predict, 0.001 val)
|
||||
iou: 0.7 # intersection over union (IoU) threshold for NMS
|
||||
max_det: 300 # maximum number of detections per image
|
||||
half: False # use half precision (FP16)
|
||||
dnn: False # use OpenCV DNN for ONNX inference
|
||||
plots: True # show plots during training
|
||||
|
||||
# Prediction settings --------------------------------------------------------------------------------------------------
|
||||
source: null # source directory for images or videos
|
||||
show: False # show results if possible
|
||||
save_txt: False # save results as .txt file
|
||||
save_conf: False # save results with confidence scores
|
||||
save_crop: False # save cropped images with results
|
||||
hide_labels: False # hide labels
|
||||
hide_conf: False # hide confidence scores
|
||||
vid_stride: 1 # video frame-rate stride
|
||||
line_thickness: 3 # bounding box thickness (pixels)
|
||||
visualize: False # visualize results
|
||||
augment: False # apply data augmentation to images
|
||||
agnostic_nms: False # class-agnostic NMS
|
||||
retina_masks: False # use retina masks for object detection
|
||||
source: null # source directory for images or videos
|
||||
show: False # show results if possible
|
||||
save_txt: False # save results as .txt file
|
||||
save_conf: False # save results with confidence scores
|
||||
save_crop: False # save cropped images with results
|
||||
hide_labels: False # hide labels
|
||||
hide_conf: False # hide confidence scores
|
||||
vid_stride: 1 # video frame-rate stride
|
||||
line_thickness: 3 # bounding box thickness (pixels)
|
||||
visualize: False # visualize results
|
||||
augment: False # apply data augmentation to images
|
||||
agnostic_nms: False # class-agnostic NMS
|
||||
retina_masks: False # use retina masks for object detection
|
||||
|
||||
# Export settings ------------------------------------------------------------------------------------------------------
|
||||
format: torchscript # format to export to
|
||||
format: torchscript # format to export to
|
||||
keras: False # use Keras
|
||||
optimize: False # TorchScript: optimize for mobile
|
||||
int8: False # CoreML/TF INT8 quantization
|
||||
@ -100,12 +100,8 @@ mosaic: 1.0 # image mosaic (probability)
|
||||
mixup: 0.0 # image mixup (probability)
|
||||
copy_paste: 0.0 # segment copy-paste (probability)
|
||||
|
||||
# Hydra configs --------------------------------------------------------------------------------------------------------
|
||||
cfg: null # for overriding defaults.yaml
|
||||
hydra:
|
||||
output_subdir: null # disable hydra directory creation
|
||||
run:
|
||||
dir: .
|
||||
# Custom config.yaml ---------------------------------------------------------------------------------------------------
|
||||
cfg: null # for overriding defaults.yaml
|
||||
|
||||
# Debug, do not modify -------------------------------------------------------------------------------------------------
|
||||
v5loader: False # use legacy YOLOv5 dataloader
|
||||
|
@ -1,77 +0,0 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import sys
|
||||
from difflib import get_close_matches
|
||||
from textwrap import dedent
|
||||
|
||||
import hydra
|
||||
from hydra.errors import ConfigCompositionException
|
||||
from omegaconf import OmegaConf, open_dict # noqa
|
||||
from omegaconf.errors import ConfigAttributeError, ConfigKeyError, OmegaConfBaseException # noqa
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, colorstr
|
||||
|
||||
|
||||
def override_config(overrides, cfg):
|
||||
override_keys = [override.key_or_group for override in overrides]
|
||||
check_config_mismatch(override_keys, cfg.keys())
|
||||
for override in overrides:
|
||||
if override.package is not None:
|
||||
raise ConfigCompositionException(f"Override {override.input_line} looks like a config group"
|
||||
f" override, but config group '{override.key_or_group}' does not exist.")
|
||||
|
||||
key = override.key_or_group
|
||||
value = override.value()
|
||||
try:
|
||||
if override.is_delete():
|
||||
config_val = OmegaConf.select(cfg, key, throw_on_missing=False)
|
||||
if config_val is None:
|
||||
raise ConfigCompositionException(f"Could not delete from config. '{override.key_or_group}'"
|
||||
" does not exist.")
|
||||
elif value is not None and value != config_val:
|
||||
raise ConfigCompositionException("Could not delete from config. The value of"
|
||||
f" '{override.key_or_group}' is {config_val} and not"
|
||||
f" {value}.")
|
||||
|
||||
last_dot = key.rfind(".")
|
||||
with open_dict(cfg):
|
||||
if last_dot == -1:
|
||||
del cfg[key]
|
||||
else:
|
||||
node = OmegaConf.select(cfg, key[:last_dot])
|
||||
del node[key[last_dot + 1:]]
|
||||
|
||||
elif override.is_add():
|
||||
if OmegaConf.select(cfg, key, throw_on_missing=False) is None or isinstance(value, (dict, list)):
|
||||
OmegaConf.update(cfg, key, value, merge=True, force_add=True)
|
||||
else:
|
||||
assert override.input_line is not None
|
||||
raise ConfigCompositionException(
|
||||
dedent(f"""\
|
||||
Could not append to config. An item is already at '{override.key_or_group}'.
|
||||
Either remove + prefix: '{override.input_line[1:]}'
|
||||
Or add a second + to add or override '{override.key_or_group}': '+{override.input_line}'
|
||||
"""))
|
||||
elif override.is_force_add():
|
||||
OmegaConf.update(cfg, key, value, merge=True, force_add=True)
|
||||
else:
|
||||
try:
|
||||
OmegaConf.update(cfg, key, value, merge=True)
|
||||
except (ConfigAttributeError, ConfigKeyError) as ex:
|
||||
raise ConfigCompositionException(f"Could not override '{override.key_or_group}'."
|
||||
f"\nTo append to your config use +{override.input_line}") from ex
|
||||
except OmegaConfBaseException as ex:
|
||||
raise ConfigCompositionException(f"Error merging override {override.input_line}").with_traceback(
|
||||
sys.exc_info()[2]) from ex
|
||||
|
||||
|
||||
def check_config_mismatch(overrides, cfg):
|
||||
mismatched = [option for option in overrides if option not in cfg and 'hydra.' not in option]
|
||||
|
||||
for option in mismatched:
|
||||
LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, cfg, 3, 0.6)}")
|
||||
if mismatched:
|
||||
sys.exit()
|
||||
|
||||
|
||||
hydra._internal.config_loader_impl.ConfigLoaderImpl._apply_overrides_to_config = override_config
|
@ -69,8 +69,8 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
|
||||
augment=mode == "train", # augmentation
|
||||
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
||||
rect=cfg.rect if mode == "train" else True, # rectangular batches
|
||||
cache=cfg.get("cache", None),
|
||||
single_cls=cfg.get("single_cls", False),
|
||||
cache=cfg.cache or None,
|
||||
single_cls=cfg.single_cls or False,
|
||||
stride=int(stride),
|
||||
pad=0.0 if mode == "train" else 0.5,
|
||||
prefix=colorstr(f"{mode}: "),
|
||||
|
@ -29,7 +29,8 @@ from torch.utils.data import DataLoader, Dataset, dataloader, distributed
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.data.utils import check_dataset, unzip_file
|
||||
from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_kaggle
|
||||
from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
|
||||
is_kaggle)
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
||||
from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
|
||||
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
|
||||
@ -493,7 +494,7 @@ class LoadImagesAndLabels(Dataset):
|
||||
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
|
||||
assert cache['version'] == self.cache_version # matches current version
|
||||
assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
|
||||
except (FileNotFoundError, AssertionError):
|
||||
except (FileNotFoundError, AssertionError, AttributeError):
|
||||
cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops
|
||||
|
||||
# Display cache
|
||||
@ -579,16 +580,17 @@ class LoadImagesAndLabels(Dataset):
|
||||
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
||||
self.im_hw0, self.im_hw = [None] * n, [None] * n
|
||||
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
|
||||
results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
|
||||
pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if cache_images == 'disk':
|
||||
b += self.npy_files[i].stat().st_size
|
||||
else: # 'ram'
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||
b += self.ims[i].nbytes
|
||||
pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
|
||||
pbar.close()
|
||||
with (Pool if n > 10000 else ThreadPool)(NUM_THREADS) as pool:
|
||||
results = pool.imap(fcn, range(n))
|
||||
pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if cache_images == 'disk':
|
||||
b += self.npy_files[i].stat().st_size
|
||||
else: # 'ram'
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||
b += self.ims[i].nbytes
|
||||
pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
|
||||
pbar.close()
|
||||
|
||||
def check_cache_ram(self, safety_margin=0.1, prefix=''):
|
||||
# Check image caching requirements vs available memory
|
||||
@ -612,11 +614,10 @@ class LoadImagesAndLabels(Dataset):
|
||||
x = {} # dict
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f"{prefix}Scanning {path.parent / path.stem}..."
|
||||
with Pool(NUM_THREADS) as pool:
|
||||
pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
|
||||
desc=desc,
|
||||
total=len(self.im_files),
|
||||
bar_format=TQDM_BAR_FORMAT)
|
||||
total = len(self.im_files)
|
||||
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
|
||||
results = pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix)))
|
||||
pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
|
||||
for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
||||
nm += nm_f
|
||||
nf += nf_f
|
||||
@ -627,8 +628,8 @@ class LoadImagesAndLabels(Dataset):
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
pbar.close()
|
||||
|
||||
pbar.close()
|
||||
if msgs:
|
||||
LOGGER.info('\n'.join(msgs))
|
||||
if nf == 0:
|
||||
@ -637,12 +638,12 @@ class LoadImagesAndLabels(Dataset):
|
||||
x['results'] = nf, nm, ne, nc, len(self.im_files)
|
||||
x['msgs'] = msgs # warnings
|
||||
x['version'] = self.cache_version # cache version
|
||||
try:
|
||||
np.save(path, x) # save cache for next time
|
||||
if is_dir_writeable(path.parent):
|
||||
np.save(str(path), x) # save cache for next time
|
||||
path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
|
||||
LOGGER.info(f'{prefix}New cache created: {path}')
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}') # not writeable
|
||||
else:
|
||||
LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable') # not writeable
|
||||
return x
|
||||
|
||||
def __len__(self):
|
||||
@ -1148,8 +1149,10 @@ class HUBDatasetStats():
|
||||
continue
|
||||
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
|
||||
desc = f'{split} images'
|
||||
for _ in tqdm(ThreadPool(NUM_THREADS).imap(self._hub_ops, dataset.im_files), total=dataset.n, desc=desc):
|
||||
pass
|
||||
total = dataset.n
|
||||
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
|
||||
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc):
|
||||
pass
|
||||
print(f'Done. All images saved to {self.im_dir}')
|
||||
return self.im_dir
|
||||
|
||||
|
@ -1,13 +1,13 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
from itertools import repeat
|
||||
from multiprocessing.pool import Pool
|
||||
from multiprocessing.pool import Pool, ThreadPool
|
||||
from pathlib import Path
|
||||
|
||||
import torchvision
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
|
||||
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
|
||||
from .augment import *
|
||||
from .base import BaseDataset
|
||||
from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
|
||||
@ -50,14 +50,12 @@ class YOLODataset(BaseDataset):
|
||||
x = {"labels": []}
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
||||
with Pool(NUM_THREADS) as pool:
|
||||
pbar = tqdm(
|
||||
pool.imap(verify_image_label,
|
||||
zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
|
||||
desc=desc,
|
||||
total=len(self.im_files),
|
||||
bar_format=TQDM_BAR_FORMAT,
|
||||
)
|
||||
total = len(self.im_files)
|
||||
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
|
||||
results = pool.imap(func=verify_image_label,
|
||||
iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
|
||||
repeat(self.use_keypoints)))
|
||||
pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
|
||||
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
||||
nm += nm_f
|
||||
nf += nf_f
|
||||
@ -73,13 +71,12 @@ class YOLODataset(BaseDataset):
|
||||
segments=segments,
|
||||
keypoints=keypoint,
|
||||
normalized=True,
|
||||
bbox_format="xywh",
|
||||
))
|
||||
bbox_format="xywh"))
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
pbar.close()
|
||||
|
||||
pbar.close()
|
||||
if msgs:
|
||||
LOGGER.info("\n".join(msgs))
|
||||
if nf == 0:
|
||||
@ -89,13 +86,12 @@ class YOLODataset(BaseDataset):
|
||||
x["msgs"] = msgs # warnings
|
||||
x["version"] = self.cache_version # cache version
|
||||
self.im_files = [lb["im_file"] for lb in x["labels"]]
|
||||
try:
|
||||
np.save(path, x) # save cache for next time
|
||||
if is_dir_writeable(path.parent):
|
||||
np.save(str(path), x) # save cache for next time
|
||||
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
||||
LOGGER.info(f"{self.prefix}New cache created: {path}")
|
||||
except Exception as e:
|
||||
LOGGER.warning(
|
||||
f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}") # not writeable
|
||||
else:
|
||||
LOGGER.warning(f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable") # not writeable
|
||||
return x
|
||||
|
||||
def get_labels(self):
|
||||
@ -105,7 +101,7 @@ class YOLODataset(BaseDataset):
|
||||
cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
|
||||
assert cache["version"] == self.cache_version # matches current version
|
||||
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
|
||||
except (FileNotFoundError, AssertionError):
|
||||
except (FileNotFoundError, AssertionError, AttributeError):
|
||||
cache, exists = self.cache_labels(cache_path), False # run cache ops
|
||||
|
||||
# Display cache
|
||||
|
@ -99,7 +99,7 @@ names:
|
||||
|
||||
# Download script/URL (optional)
|
||||
download: |
|
||||
from ultralytics.yoloutils.downloads import download
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from pathlib import Path
|
||||
|
||||
# Download labels
|
||||
|
@ -60,7 +60,6 @@ from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
@ -71,7 +70,7 @@ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, Segmentati
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
||||
from ultralytics.yolo.data.utils import check_dataset
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
||||
from ultralytics.yolo.utils.files import file_size
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
@ -123,11 +122,11 @@ class Exporter:
|
||||
A class for exporting a model.
|
||||
|
||||
Attributes:
|
||||
args (OmegaConf): Configuration for the exporter.
|
||||
args (SimpleNamespace): Configuration for the exporter.
|
||||
save_dir (Path): Directory to save results.
|
||||
"""
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
def __init__(self, config=DEFAULT_CFG, overrides=None):
|
||||
"""
|
||||
Initializes the Exporter class.
|
||||
|
||||
@ -135,8 +134,6 @@ class Exporter:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
self.args = get_config(config, overrides)
|
||||
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
||||
callbacks.add_integration_callbacks(self)
|
||||
@ -799,8 +796,7 @@ class Exporter:
|
||||
callback(self)
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def export(cfg):
|
||||
def export(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n.yaml"
|
||||
cfg.format = cfg.format or "torchscript"
|
||||
|
||||
@ -818,7 +814,7 @@ def export(cfg):
|
||||
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(cfg.model)
|
||||
model.export(**cfg)
|
||||
model.export(**vars(cfg))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -6,7 +6,7 @@ from ultralytics import yolo # noqa
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
|
||||
|
||||
@ -151,7 +151,7 @@ class YOLO:
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
overrides["mode"] = "val"
|
||||
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
|
||||
args = get_config(config=DEFAULT_CFG_PATH, overrides=overrides)
|
||||
args.data = data or args.data
|
||||
args.task = self.task
|
||||
|
||||
@ -169,7 +169,7 @@ class YOLO:
|
||||
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
|
||||
args = get_config(config=DEFAULT_CFG_PATH, overrides=overrides)
|
||||
args.task = self.task
|
||||
|
||||
print(args)
|
||||
|
@ -36,7 +36,7 @@ from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, SETTINGS, callbacks, colorstr, ops
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, SETTINGS, callbacks, colorstr, ops
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||
@ -49,7 +49,7 @@ class BasePredictor:
|
||||
A base class for creating predictors.
|
||||
|
||||
Attributes:
|
||||
args (OmegaConf): Configuration for the predictor.
|
||||
args (SimpleNamespace): Configuration for the predictor.
|
||||
save_dir (Path): Directory to save results.
|
||||
done_setup (bool): Whether the predictor has finished setup.
|
||||
model (nn.Module): Model used for prediction.
|
||||
@ -62,7 +62,7 @@ class BasePredictor:
|
||||
data_path (str): Path to data.
|
||||
"""
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
def __init__(self, config=DEFAULT_CFG_PATH, overrides=None):
|
||||
"""
|
||||
Initializes the BasePredictor class.
|
||||
|
||||
@ -70,8 +70,6 @@ class BasePredictor:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
self.args = get_config(config, overrides)
|
||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||
name = self.args.name or f"{self.args.mode}"
|
||||
@ -157,7 +155,7 @@ class BasePredictor:
|
||||
if stream:
|
||||
return self.stream_inference(source, model, verbose)
|
||||
else:
|
||||
return list(chain(*list(self.stream_inference(source, model, verbose)))) # merge list of Result into one
|
||||
return list(self.stream_inference(source, model, verbose)) # merge list of Result into one
|
||||
|
||||
def predict_cli(self):
|
||||
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
|
||||
@ -211,7 +209,7 @@ class BasePredictor:
|
||||
if self.args.save:
|
||||
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
||||
|
||||
yield results
|
||||
yield from results
|
||||
|
||||
# Print time (inference-only)
|
||||
if verbose:
|
||||
|
@ -15,8 +15,6 @@ import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf # noqa
|
||||
from omegaconf import open_dict
|
||||
from torch.cuda import amp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import lr_scheduler
|
||||
@ -27,7 +25,7 @@ from ultralytics import __version__
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
||||
yaml_save)
|
||||
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
|
||||
@ -43,7 +41,7 @@ class BaseTrainer:
|
||||
A base class for creating trainers.
|
||||
|
||||
Attributes:
|
||||
args (OmegaConf): Configuration for the trainer.
|
||||
args (SimpleNamespace): Configuration for the trainer.
|
||||
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
|
||||
console (logging.Logger): Logger instance.
|
||||
validator (BaseValidator): Validator instance.
|
||||
@ -73,7 +71,7 @@ class BaseTrainer:
|
||||
csv (Path): Path to results CSV file.
|
||||
"""
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
def __init__(self, config=DEFAULT_CFG_PATH, overrides=None):
|
||||
"""
|
||||
Initializes the BaseTrainer class.
|
||||
|
||||
@ -81,8 +79,6 @@ class BaseTrainer:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
self.args = get_config(config, overrides)
|
||||
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch)
|
||||
self.check_resume()
|
||||
@ -95,23 +91,23 @@ class BaseTrainer:
|
||||
# Dirs
|
||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||
name = self.args.name or f"{self.args.mode}"
|
||||
self.save_dir = Path(
|
||||
self.args.get(
|
||||
"save_dir",
|
||||
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)))
|
||||
if hasattr(self.args, 'save_dir'):
|
||||
self.save_dir = Path(self.args.save_dir)
|
||||
else:
|
||||
self.save_dir = Path(
|
||||
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True))
|
||||
self.wdir = self.save_dir / 'weights' # weights dir
|
||||
if RANK in {-1, 0}:
|
||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||
with open_dict(self.args):
|
||||
self.args.save_dir = str(self.save_dir)
|
||||
yaml_save(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) # save run args
|
||||
self.args.save_dir = str(self.save_dir)
|
||||
yaml_save(self.save_dir / 'args.yaml', vars(self.args)) # save run args
|
||||
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
|
||||
|
||||
self.batch_size = self.args.batch
|
||||
self.epochs = self.args.epochs
|
||||
self.start_epoch = 0
|
||||
if RANK == -1:
|
||||
print_args(dict(self.args))
|
||||
print_args(vars(self.args))
|
||||
|
||||
# Device
|
||||
self.amp = self.device.type != 'cpu'
|
||||
@ -373,7 +369,7 @@ class BaseTrainer:
|
||||
'ema': deepcopy(self.ema.ema).half(),
|
||||
'updates': self.ema.updates,
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'train_args': self.args,
|
||||
'train_args': vars(self.args), # save as dict
|
||||
'date': datetime.now().isoformat(),
|
||||
'version': __version__}
|
||||
|
||||
|
@ -5,12 +5,12 @@ from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf # noqa
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks
|
||||
from ultralytics.yolo.utils.checks import check_imgsz
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
@ -27,7 +27,7 @@ class BaseValidator:
|
||||
dataloader (DataLoader): Dataloader to use for validation.
|
||||
pbar (tqdm): Progress bar to update during validation.
|
||||
logger (logging.Logger): Logger to use for validation.
|
||||
args (OmegaConf): Configuration for the validator.
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
model (nn.Module): Model to validate.
|
||||
data (dict): Data dictionary.
|
||||
device (torch.device): Device to use for validation.
|
||||
@ -47,12 +47,12 @@ class BaseValidator:
|
||||
save_dir (Path): Directory to save results.
|
||||
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
||||
logger (logging.Logger): Logger to log messages.
|
||||
args (OmegaConf): Configuration for the validator.
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
"""
|
||||
self.dataloader = dataloader
|
||||
self.pbar = pbar
|
||||
self.logger = logger or LOGGER
|
||||
self.args = args or OmegaConf.load(DEFAULT_CONFIG)
|
||||
self.args = args or get_config(DEFAULT_CFG_PATH)
|
||||
self.model = None
|
||||
self.data = None
|
||||
self.device = None
|
||||
|
@ -8,6 +8,7 @@ import platform
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import types
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
@ -22,7 +23,7 @@ import yaml
|
||||
# Constants
|
||||
FILE = Path(__file__).resolve()
|
||||
ROOT = FILE.parents[2] # YOLO
|
||||
DEFAULT_CONFIG = ROOT / "yolo/configs/default.yaml"
|
||||
DEFAULT_CFG_PATH = ROOT / "yolo/configs/default.yaml"
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
||||
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
||||
@ -73,9 +74,10 @@ os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
|
||||
|
||||
# Default config dictionary
|
||||
with open(DEFAULT_CONFIG, errors='ignore') as f:
|
||||
DEFAULT_CONFIG_DICT = yaml.safe_load(f)
|
||||
DEFAULT_CONFIG_KEYS = DEFAULT_CONFIG_DICT.keys()
|
||||
with open(DEFAULT_CFG_PATH, errors='ignore') as f:
|
||||
DEFAULT_CFG_DICT = yaml.safe_load(f)
|
||||
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
|
||||
DEFAULT_CFG = types.SimpleNamespace(**DEFAULT_CFG_DICT)
|
||||
|
||||
|
||||
def is_colab():
|
||||
|
@ -28,7 +28,7 @@ def generate_ddp_file(trainer):
|
||||
|
||||
if not trainer.resume:
|
||||
shutil.rmtree(trainer.save_dir) # remove the save_dir
|
||||
content = f'''config = {dict(trainer.args)} \nif __name__ == "__main__":
|
||||
content = f'''config = {vars(trainer.args)} \nif __name__ == "__main__":
|
||||
from ultralytics.{import_path} import {trainer.__class__.__name__}
|
||||
|
||||
trainer = {trainer.__class__.__name__}(config=config)
|
||||
|
@ -18,7 +18,7 @@ import torch.nn.functional as F
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import ultralytics
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER
|
||||
from ultralytics.yolo.utils.checks import git_describe
|
||||
|
||||
from .checks import check_version
|
||||
@ -288,7 +288,7 @@ def strip_optimizer(f='best.pt', s=''):
|
||||
None
|
||||
"""
|
||||
x = torch.load(f, map_location=torch.device('cpu'))
|
||||
args = {**DEFAULT_CONFIG_DICT, **x['train_args']} # combine model args with default args, preferring model args
|
||||
args = {**DEFAULT_CFG_DICT, **x['train_args']} # combine model args with default args, preferring model args
|
||||
if x.get('ema'):
|
||||
x['model'] = x['ema'] # replace model with ema
|
||||
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
|
||||
@ -297,7 +297,8 @@ def strip_optimizer(f='best.pt', s=''):
|
||||
x['model'].half() # to FP16
|
||||
for p in x['model'].parameters():
|
||||
p.requires_grad = False
|
||||
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # strip non-default keys
|
||||
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
|
||||
# x['model'].args = x['train_args']
|
||||
torch.save(x, s or f)
|
||||
mb = os.path.getsize(s or f) / 1E6 # filesize
|
||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||
|
@ -1,6 +1,5 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
from ultralytics.yolo.configs import hydra_patch # noqa (patch hydra CLI)
|
||||
from ultralytics.yolo.v8 import classify, detect, segment
|
||||
|
||||
__all__ = ["classify", "segment", "detect"]
|
||||
|
@ -1,11 +1,10 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory
|
||||
from ultralytics.yolo.utils.plotting import Annotator
|
||||
|
||||
|
||||
@ -64,8 +63,7 @@ class ClassificationPredictor(BasePredictor):
|
||||
return log_string
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def predict(cfg):
|
||||
def predict(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
||||
else "https://ultralytics.com/images/bus.jpg"
|
||||
|
@ -1,6 +1,5 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
@ -8,13 +7,13 @@ from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
|
||||
from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.data import build_classification_dataloader
|
||||
from ultralytics.yolo.engine.trainer import BaseTrainer
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG
|
||||
from ultralytics.yolo.utils.torch_utils import strip_optimizer
|
||||
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
def __init__(self, config=DEFAULT_CFG, overrides=None):
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides["task"] = "classify"
|
||||
@ -136,8 +135,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
# self.run_callbacks('on_fit_epoch_end')
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
def train(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||
cfg.data = cfg.data or "mnist160" # or yolo.ClassificationDataset("mnist")
|
||||
|
||||
@ -152,7 +150,7 @@ def train(cfg):
|
||||
# trainer.train()
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(cfg.model)
|
||||
model.train(**cfg)
|
||||
model.train(**vars(cfg))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,10 +1,8 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import hydra
|
||||
|
||||
from ultralytics.yolo.data import build_classification_dataloader
|
||||
from ultralytics.yolo.engine.validator import BaseValidator
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG
|
||||
from ultralytics.yolo.utils.metrics import ClassifyMetrics
|
||||
|
||||
|
||||
@ -46,8 +44,7 @@ class ClassificationValidator(BaseValidator):
|
||||
self.logger.info(pf % ("all", self.metrics.top1, self.metrics.top5))
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def val(cfg):
|
||||
def val(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||
cfg.data = cfg.data or "imagenette160"
|
||||
validator = ClassificationValidator(args=cfg)
|
||||
|
@ -1,11 +1,10 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory, ops
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory, ops
|
||||
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
||||
|
||||
|
||||
@ -81,8 +80,7 @@ class DetectionPredictor(BasePredictor):
|
||||
return log_string
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def predict(cfg):
|
||||
def predict(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n.pt"
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
||||
else "https://ultralytics.com/images/bus.jpg"
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from copy import copy
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -11,7 +10,7 @@ from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.data import build_dataloader
|
||||
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
|
||||
from ultralytics.yolo.engine.trainer import BaseTrainer
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, colorstr
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, colorstr
|
||||
from ultralytics.yolo.utils.loss import BboxLoss
|
||||
from ultralytics.yolo.utils.ops import xywh2xyxy
|
||||
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
||||
@ -30,7 +29,7 @@ class DetectionTrainer(BaseTrainer):
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch_size,
|
||||
stride=gs,
|
||||
hyp=dict(self.args),
|
||||
hyp=vars(self.args),
|
||||
augment=mode == "train",
|
||||
cache=self.args.cache,
|
||||
pad=0 if mode == "train" else 0.5,
|
||||
@ -195,8 +194,7 @@ class Loss:
|
||||
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
def train(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n.pt"
|
||||
cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist")
|
||||
cfg.device = cfg.device if cfg.device is not None else ''
|
||||
@ -204,7 +202,7 @@ def train(cfg):
|
||||
# trainer.train()
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(cfg.model)
|
||||
model.train(**cfg)
|
||||
model.train(**vars(cfg))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -3,14 +3,13 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.data import build_dataloader
|
||||
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
|
||||
from ultralytics.yolo.engine.validator import BaseValidator
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, colorstr, ops, yaml_load
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, colorstr, ops, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_file, check_requirements
|
||||
from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
|
||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
|
||||
@ -168,7 +167,7 @@ class DetectionValidator(BaseValidator):
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch_size,
|
||||
stride=gs,
|
||||
hyp=dict(self.args),
|
||||
hyp=vars(self.args),
|
||||
cache=False,
|
||||
pad=0.5,
|
||||
rect=True,
|
||||
@ -232,8 +231,7 @@ class DetectionValidator(BaseValidator):
|
||||
return stats
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def val(cfg):
|
||||
def val(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n.pt"
|
||||
cfg.data = cfg.data or "coco128.yaml"
|
||||
validator = DetectionValidator(args=cfg)
|
||||
|
@ -1,10 +1,9 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory, ops
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory, ops
|
||||
from ultralytics.yolo.utils.plotting import colors, save_one_box
|
||||
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
||||
|
||||
@ -98,8 +97,7 @@ class SegmentationPredictor(DetectionPredictor):
|
||||
return log_string
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def predict(cfg):
|
||||
def predict(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n-seg.pt"
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
||||
else "https://ultralytics.com/images/bus.jpg"
|
||||
|
@ -2,13 +2,12 @@
|
||||
|
||||
from copy import copy
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.nn.tasks import SegmentationModel
|
||||
from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG
|
||||
from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh
|
||||
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
||||
from ultralytics.yolo.utils.tal import make_anchors
|
||||
@ -19,7 +18,7 @@ from ultralytics.yolo.v8.detect.train import Loss
|
||||
# BaseTrainer python usage
|
||||
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
def __init__(self, config=DEFAULT_CFG, overrides=None):
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides["task"] = "segment"
|
||||
@ -141,8 +140,7 @@ class SegLoss(Loss):
|
||||
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
def train(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n-seg.pt"
|
||||
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
|
||||
cfg.device = cfg.device if cfg.device is not None else ''
|
||||
@ -150,7 +148,7 @@ def train(cfg):
|
||||
# trainer.train()
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(cfg.model)
|
||||
model.train(**cfg)
|
||||
model.train(**vars(cfg))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -4,12 +4,11 @@ import os
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, NUM_THREADS, ops
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, NUM_THREADS, ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou
|
||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
|
||||
@ -243,8 +242,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
return stats
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def val(cfg):
|
||||
def val(cfg=DEFAULT_CFG):
|
||||
cfg.data = cfg.data or "coco128-seg.yaml"
|
||||
validator = SegmentationValidator(args=cfg)
|
||||
validator(model=cfg.model)
|
||||
|
Reference in New Issue
Block a user