ultralytics 8.0.136
refactor and simplify package (#3748)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -1,14 +1,13 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = '8.0.135'
|
||||
__version__ = '8.0.136'
|
||||
|
||||
from ultralytics.engine.model import YOLO
|
||||
from ultralytics.hub import start
|
||||
from ultralytics.vit.rtdetr import RTDETR
|
||||
from ultralytics.vit.sam import SAM
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.fastsam import FastSAM
|
||||
from ultralytics.yolo.nas import NAS
|
||||
from ultralytics.yolo.utils.checks import check_yolo as checks
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.models import RTDETR, SAM
|
||||
from ultralytics.models.fastsam import FastSAM
|
||||
from ultralytics.models.nas import NAS
|
||||
from ultralytics.utils.checks import check_yolo as checks
|
||||
from ultralytics.utils.downloads import download
|
||||
|
||||
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'start' # allow simpler import
|
||||
|
421
ultralytics/cfg/__init__.py
Normal file
421
ultralytics/cfg/__init__.py
Normal file
@ -0,0 +1,421 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from difflib import get_close_matches
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR,
|
||||
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, get_settings,
|
||||
yaml_load, yaml_print)
|
||||
|
||||
# Define valid tasks and modes
|
||||
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
||||
TASKS = 'detect', 'segment', 'classify', 'pose'
|
||||
TASK2DATA = {'detect': 'coco8.yaml', 'segment': 'coco8-seg.yaml', 'classify': 'imagenet100', 'pose': 'coco8-pose.yaml'}
|
||||
TASK2MODEL = {
|
||||
'detect': 'yolov8n.pt',
|
||||
'segment': 'yolov8n-seg.pt',
|
||||
'classify': 'yolov8n-cls.pt',
|
||||
'pose': 'yolov8n-pose.pt'}
|
||||
TASK2METRIC = {
|
||||
'detect': 'metrics/mAP50-95(B)',
|
||||
'segment': 'metrics/mAP50-95(M)',
|
||||
'classify': 'metrics/accuracy_top1',
|
||||
'pose': 'metrics/mAP50-95(P)'}
|
||||
|
||||
|
||||
CLI_HELP_MSG = \
|
||||
f"""
|
||||
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
|
||||
|
||||
yolo TASK MODE ARGS
|
||||
|
||||
Where TASK (optional) is one of {TASKS}
|
||||
MODE (required) is one of {MODES}
|
||||
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||||
See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'
|
||||
|
||||
1. Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
||||
yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||
|
||||
2. Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||
yolo predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
|
||||
|
||||
3. Val a pretrained detection model at batch-size 1 and image size 640:
|
||||
yolo val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||
|
||||
4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||||
|
||||
5. Run special commands:
|
||||
yolo help
|
||||
yolo checks
|
||||
yolo version
|
||||
yolo settings
|
||||
yolo copy-cfg
|
||||
yolo cfg
|
||||
|
||||
Docs: https://docs.ultralytics.com
|
||||
Community: https://community.ultralytics.com
|
||||
GitHub: https://github.com/ultralytics/ultralytics
|
||||
"""
|
||||
|
||||
# Define keys for arg type checks
|
||||
CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'
|
||||
CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
|
||||
'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
|
||||
'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction') # fraction floats 0.0 - 1.0
|
||||
CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
|
||||
'line_width', 'workspace', 'nbs', 'save_period')
|
||||
CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val',
|
||||
'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop',
|
||||
'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
|
||||
'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile')
|
||||
|
||||
|
||||
def cfg2dict(cfg):
|
||||
"""
|
||||
Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | SimpleNamespace): Configuration object to be converted to a dictionary.
|
||||
|
||||
Returns:
|
||||
cfg (dict): Configuration object in dictionary format.
|
||||
"""
|
||||
if isinstance(cfg, (str, Path)):
|
||||
cfg = yaml_load(cfg) # load dict
|
||||
elif isinstance(cfg, SimpleNamespace):
|
||||
cfg = vars(cfg) # convert to dict
|
||||
return cfg
|
||||
|
||||
|
||||
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
|
||||
"""
|
||||
Load and merge configuration data from a file or dictionary.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | Dict | SimpleNamespace): Configuration data.
|
||||
overrides (str | Dict | optional): Overrides in the form of a file name or a dictionary. Default is None.
|
||||
|
||||
Returns:
|
||||
(SimpleNamespace): Training arguments namespace.
|
||||
"""
|
||||
cfg = cfg2dict(cfg)
|
||||
|
||||
# Merge overrides
|
||||
if overrides:
|
||||
overrides = cfg2dict(overrides)
|
||||
check_cfg_mismatch(cfg, overrides)
|
||||
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
|
||||
|
||||
# Special handling for numeric project/name
|
||||
for k in 'project', 'name':
|
||||
if k in cfg and isinstance(cfg[k], (int, float)):
|
||||
cfg[k] = str(cfg[k])
|
||||
if cfg.get('name') == 'model': # assign model to 'name' arg
|
||||
cfg['name'] = cfg.get('model', '').split('.')[0]
|
||||
LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
|
||||
|
||||
# Type and Value checks
|
||||
for k, v in cfg.items():
|
||||
if v is not None: # None values may be from optional args
|
||||
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
|
||||
elif k in CFG_FRACTION_KEYS:
|
||||
if not isinstance(v, (int, float)):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
|
||||
if not (0.0 <= v <= 1.0):
|
||||
raise ValueError(f"'{k}={v}' is an invalid value. "
|
||||
f"Valid '{k}' values are between 0.0 and 1.0.")
|
||||
elif k in CFG_INT_KEYS and not isinstance(v, int):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"'{k}' must be an int (i.e. '{k}=8')")
|
||||
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
|
||||
|
||||
# Return instance
|
||||
return IterableSimpleNamespace(**cfg)
|
||||
|
||||
|
||||
def _handle_deprecation(custom):
|
||||
"""
|
||||
Hardcoded function to handle deprecated config keys
|
||||
"""
|
||||
|
||||
for key in custom.copy().keys():
|
||||
if key == 'hide_labels':
|
||||
deprecation_warn(key, 'show_labels')
|
||||
custom['show_labels'] = custom.pop('hide_labels') == 'False'
|
||||
if key == 'hide_conf':
|
||||
deprecation_warn(key, 'show_conf')
|
||||
custom['show_conf'] = custom.pop('hide_conf') == 'False'
|
||||
if key == 'line_thickness':
|
||||
deprecation_warn(key, 'line_width')
|
||||
custom['line_width'] = custom.pop('line_thickness')
|
||||
|
||||
return custom
|
||||
|
||||
|
||||
def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
|
||||
"""
|
||||
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
|
||||
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
|
||||
|
||||
Args:
|
||||
custom (dict): a dictionary of custom configuration options
|
||||
base (dict): a dictionary of base configuration options
|
||||
"""
|
||||
custom = _handle_deprecation(custom)
|
||||
base, custom = (set(x.keys()) for x in (base, custom))
|
||||
mismatched = [x for x in custom if x not in base]
|
||||
if mismatched:
|
||||
string = ''
|
||||
for x in mismatched:
|
||||
matches = get_close_matches(x, base) # key list
|
||||
matches = [f'{k}={DEFAULT_CFG_DICT[k]}' if DEFAULT_CFG_DICT.get(k) is not None else k for k in matches]
|
||||
match_str = f'Similar arguments are i.e. {matches}.' if matches else ''
|
||||
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
|
||||
raise SyntaxError(string + CLI_HELP_MSG) from e
|
||||
|
||||
|
||||
def merge_equals_args(args: List[str]) -> List[str]:
|
||||
"""
|
||||
Merges arguments around isolated '=' args in a list of strings.
|
||||
The function considers cases where the first argument ends with '=' or the second starts with '=',
|
||||
as well as when the middle one is an equals sign.
|
||||
|
||||
Args:
|
||||
args (List[str]): A list of strings where each element is an argument.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of strings where the arguments around isolated '=' are merged.
|
||||
"""
|
||||
new_args = []
|
||||
for i, arg in enumerate(args):
|
||||
if arg == '=' and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
|
||||
new_args[-1] += f'={args[i + 1]}'
|
||||
del args[i + 1]
|
||||
elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]: # merge ['arg=', 'val']
|
||||
new_args.append(f'{arg}{args[i + 1]}')
|
||||
del args[i + 1]
|
||||
elif arg.startswith('=') and i > 0: # merge ['arg', '=val']
|
||||
new_args[-1] += arg
|
||||
else:
|
||||
new_args.append(arg)
|
||||
return new_args
|
||||
|
||||
|
||||
def handle_yolo_hub(args: List[str]) -> None:
|
||||
"""
|
||||
Handle Ultralytics HUB command-line interface (CLI) commands.
|
||||
|
||||
This function processes Ultralytics HUB CLI commands such as login and logout.
|
||||
It should be called when executing a script with arguments related to HUB authentication.
|
||||
|
||||
Args:
|
||||
args (List[str]): A list of command line arguments
|
||||
|
||||
Example:
|
||||
python my_script.py hub login your_api_key
|
||||
"""
|
||||
from ultralytics import hub
|
||||
|
||||
if args[0] == 'login':
|
||||
key = args[1] if len(args) > 1 else ''
|
||||
# Log in to Ultralytics HUB using the provided API key
|
||||
hub.login(key)
|
||||
elif args[0] == 'logout':
|
||||
# Log out from Ultralytics HUB
|
||||
hub.logout()
|
||||
|
||||
|
||||
def handle_yolo_settings(args: List[str]) -> None:
|
||||
"""
|
||||
Handle YOLO settings command-line interface (CLI) commands.
|
||||
|
||||
This function processes YOLO settings CLI commands such as reset.
|
||||
It should be called when executing a script with arguments related to YOLO settings management.
|
||||
|
||||
Args:
|
||||
args (List[str]): A list of command line arguments for YOLO settings management.
|
||||
|
||||
Example:
|
||||
python my_script.py yolo settings reset
|
||||
"""
|
||||
path = USER_CONFIG_DIR / 'settings.yaml' # get SETTINGS YAML file path
|
||||
if any(args) and args[0] == 'reset':
|
||||
path.unlink() # delete the settings file
|
||||
get_settings() # create new settings
|
||||
LOGGER.info('Settings reset successfully') # inform the user that settings have been reset
|
||||
yaml_print(path) # print the current settings
|
||||
|
||||
|
||||
def entrypoint(debug=''):
|
||||
"""
|
||||
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
||||
to the package.
|
||||
|
||||
This function allows for:
|
||||
- passing mandatory YOLO args as a list of strings
|
||||
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
|
||||
- specifying the mode, either 'train', 'val', 'test', or 'predict'
|
||||
- running special modes like 'checks'
|
||||
- passing overrides to the package's configuration
|
||||
|
||||
It uses the package's default cfg and initializes it using the passed overrides.
|
||||
Then it calls the CLI function with the composed cfg
|
||||
"""
|
||||
args = (debug.split(' ') if debug else sys.argv)[1:]
|
||||
if not args: # no arguments passed
|
||||
LOGGER.info(CLI_HELP_MSG)
|
||||
return
|
||||
|
||||
special = {
|
||||
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
||||
'checks': checks.check_yolo,
|
||||
'version': lambda: LOGGER.info(__version__),
|
||||
'settings': lambda: handle_yolo_settings(args[1:]),
|
||||
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||
'hub': lambda: handle_yolo_hub(args[1:]),
|
||||
'login': lambda: handle_yolo_hub(args),
|
||||
'copy-cfg': copy_default_cfg}
|
||||
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
||||
|
||||
# Define common mis-uses of special commands, i.e. -h, -help, --help
|
||||
special.update({k[0]: v for k, v in special.items()}) # singular
|
||||
special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith('s')}) # singular
|
||||
special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}}
|
||||
|
||||
overrides = {} # basic overrides, i.e. imgsz=320
|
||||
for a in merge_equals_args(args): # merge spaces around '=' sign
|
||||
if a.startswith('--'):
|
||||
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
|
||||
a = a[2:]
|
||||
if a.endswith(','):
|
||||
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
|
||||
a = a[:-1]
|
||||
if '=' in a:
|
||||
try:
|
||||
re.sub(r' *= *', '=', a) # remove spaces around equals sign
|
||||
k, v = a.split('=', 1) # split on first '=' sign
|
||||
assert v, f"missing '{k}' value"
|
||||
if k == 'cfg': # custom.yaml passed
|
||||
LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
|
||||
overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
|
||||
else:
|
||||
if v.lower() == 'none':
|
||||
v = None
|
||||
elif v.lower() == 'true':
|
||||
v = True
|
||||
elif v.lower() == 'false':
|
||||
v = False
|
||||
else:
|
||||
with contextlib.suppress(Exception):
|
||||
v = eval(v)
|
||||
overrides[k] = v
|
||||
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
||||
check_cfg_mismatch(full_args_dict, {a: ''}, e)
|
||||
|
||||
elif a in TASKS:
|
||||
overrides['task'] = a
|
||||
elif a in MODES:
|
||||
overrides['mode'] = a
|
||||
elif a.lower() in special:
|
||||
special[a.lower()]()
|
||||
return
|
||||
elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
|
||||
overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
|
||||
elif a in DEFAULT_CFG_DICT:
|
||||
raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
|
||||
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
|
||||
else:
|
||||
check_cfg_mismatch(full_args_dict, {a: ''})
|
||||
|
||||
# Check keys
|
||||
check_cfg_mismatch(full_args_dict, overrides)
|
||||
|
||||
# Mode
|
||||
mode = overrides.get('mode', None)
|
||||
if mode is None:
|
||||
mode = DEFAULT_CFG.mode or 'predict'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
|
||||
elif mode not in MODES:
|
||||
if mode not in ('checks', checks):
|
||||
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
|
||||
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
|
||||
checks.check_yolo()
|
||||
return
|
||||
|
||||
# Task
|
||||
task = overrides.pop('task', None)
|
||||
if task:
|
||||
if task not in TASKS:
|
||||
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
|
||||
if 'model' not in overrides:
|
||||
overrides['model'] = TASK2MODEL[task]
|
||||
|
||||
# Model
|
||||
model = overrides.pop('model', DEFAULT_CFG.model)
|
||||
if model is None:
|
||||
model = 'yolov8n.pt'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
|
||||
overrides['model'] = model
|
||||
if 'rtdetr' in model.lower(): # guess architecture
|
||||
from ultralytics import RTDETR
|
||||
model = RTDETR(model) # no task argument
|
||||
elif 'sam' in model.lower():
|
||||
from ultralytics import SAM
|
||||
model = SAM(model)
|
||||
else:
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(model, task=task)
|
||||
if isinstance(overrides.get('pretrained'), str):
|
||||
model.load(overrides['pretrained'])
|
||||
|
||||
# Task Update
|
||||
if task != model.task:
|
||||
if task:
|
||||
LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
|
||||
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
|
||||
task = model.task
|
||||
|
||||
# Mode
|
||||
if mode in ('predict', 'track') and 'source' not in overrides:
|
||||
overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||
else 'https://ultralytics.com/images/bus.jpg'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
||||
elif mode in ('train', 'val'):
|
||||
if 'data' not in overrides:
|
||||
overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
||||
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
|
||||
elif mode == 'export':
|
||||
if 'format' not in overrides:
|
||||
overrides['format'] = DEFAULT_CFG.format or 'torchscript'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
|
||||
|
||||
# Run command in python
|
||||
# getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml
|
||||
getattr(model, mode)(**overrides) # default args from model
|
||||
|
||||
|
||||
# Special modes --------------------------------------------------------------------------------------------------------
|
||||
def copy_default_cfg():
|
||||
"""Copy and create a new default configuration file with '_copy' appended to its name."""
|
||||
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
||||
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
||||
LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'
|
||||
f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Example Usage: entrypoint(debug='yolo predict model=yolov8n.pt')
|
||||
entrypoint(debug='')
|
@ -29,7 +29,7 @@ names:
|
||||
download: |
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.utils.downloads import download
|
||||
from pathlib import Path
|
||||
|
||||
def argoverse2yolo(set):
|
@ -32,7 +32,7 @@ names:
|
||||
|
||||
# Download script/URL (optional) ---------------------------------------------------------------------------------------
|
||||
download: |
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.utils.downloads import download
|
||||
from pathlib import Path
|
||||
|
||||
# Download
|
@ -386,9 +386,9 @@ names:
|
||||
download: |
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.yolo.utils.ops import xyxy2xywhn
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
from ultralytics.utils.downloads import download
|
||||
from ultralytics.utils.ops import xyxy2xywhn
|
||||
|
||||
import numpy as np
|
||||
from pathlib import Path
|
@ -27,8 +27,8 @@ download: |
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.yolo.utils.ops import xyxy2xywh
|
||||
from ultralytics.utils.downloads import download
|
||||
from ultralytics.utils.ops import xyxy2xywh
|
||||
|
||||
# Download
|
||||
dir = Path(yaml['path']) # dataset root dir
|
@ -48,7 +48,7 @@ download: |
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from tqdm import tqdm
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.utils.downloads import download
|
||||
from pathlib import Path
|
||||
|
||||
def convert_label(path, lb_path, year, image_id):
|
@ -32,7 +32,7 @@ download: |
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.utils.downloads import download
|
||||
|
||||
def visdrone2yolo(dir):
|
||||
from PIL import Image
|
@ -23,7 +23,7 @@ names:
|
||||
|
||||
# Download script/URL (optional)
|
||||
download: |
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.utils.downloads import download
|
||||
from pathlib import Path
|
||||
|
||||
# Download labels
|
@ -99,7 +99,7 @@ names:
|
||||
|
||||
# Download script/URL (optional)
|
||||
download: |
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.utils.downloads import download
|
||||
from pathlib import Path
|
||||
|
||||
# Download labels
|
@ -87,8 +87,8 @@ download: |
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.data.utils import autosplit
|
||||
from ultralytics.yolo.utils.ops import xyxy2xywhn
|
||||
from ultralytics.data.utils import autosplit
|
||||
from ultralytics.utils.ops import xyxy2xywhn
|
||||
|
||||
|
||||
def convert_labels(fname=Path('xView/xView_train.geojson')):
|
8
ultralytics/data/__init__.py
Normal file
8
ultralytics/data/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .base import BaseDataset
|
||||
from .build import build_dataloader, build_yolo_dataset, load_inference_source
|
||||
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
|
||||
|
||||
__all__ = ('BaseDataset', 'ClassificationDataset', 'SemanticDataset', 'YOLODataset', 'build_yolo_dataset',
|
||||
'build_dataloader', 'load_inference_source')
|
@ -9,11 +9,12 @@ import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
|
||||
from ..utils import LOGGER, colorstr
|
||||
from ..utils.checks import check_version
|
||||
from ..utils.instance import Instances
|
||||
from ..utils.metrics import bbox_ioa
|
||||
from ..utils.ops import segment2box
|
||||
from ultralytics.utils import LOGGER, colorstr
|
||||
from ultralytics.utils.checks import check_version
|
||||
from ultralytics.utils.instance import Instances
|
||||
from ultralytics.utils.metrics import bbox_ioa
|
||||
from ultralytics.utils.ops import segment2box
|
||||
|
||||
from .utils import polygons2masks, polygons2masks_overlap
|
||||
|
||||
POSE_FLIPLR_INDEX = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
|
@ -15,7 +15,8 @@ import psutil
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT
|
||||
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT
|
||||
|
||||
from .utils import HELP_URL, IMG_FORMATS
|
||||
|
||||
|
@ -9,12 +9,12 @@ import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import dataloader, distributed
|
||||
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots,
|
||||
LoadStreams, LoadTensor, SourceTypes, autocast_list)
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.yolo.utils.checks import check_file
|
||||
from ultralytics.data.loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams, LoadTensor,
|
||||
SourceTypes, autocast_list)
|
||||
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.utils import RANK, colorstr
|
||||
from ultralytics.utils.checks import check_file
|
||||
|
||||
from ..utils import RANK, colorstr
|
||||
from .dataset import YOLODataset
|
||||
from .utils import PIN_MEMORY
|
||||
|
@ -6,8 +6,8 @@ import cv2
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.yolo.utils.files import make_dirs
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
from ultralytics.utils.files import make_dirs
|
||||
|
||||
|
||||
def coco91_to_coco80_class():
|
@ -10,7 +10,8 @@ import torch
|
||||
import torchvision
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
|
||||
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
|
||||
|
||||
from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
|
||||
from .base import BaseDataset
|
||||
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label
|
@ -15,9 +15,9 @@ import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -318,11 +318,10 @@ class LoadTensor:
|
||||
s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \
|
||||
f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.'
|
||||
if len(im.shape) != 4:
|
||||
if len(im.shape) == 3:
|
||||
LOGGER.warning(s)
|
||||
im = im.unsqueeze(0)
|
||||
else:
|
||||
if len(im.shape) != 3:
|
||||
raise ValueError(s)
|
||||
LOGGER.warning(s)
|
||||
im = im.unsqueeze(0)
|
||||
if im.shape[2] % stride or im.shape[3] % stride:
|
||||
raise ValueError(s)
|
||||
if im.max() > 1.0:
|
@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
# Download latest models from https://github.com/ultralytics/assets/releases
|
||||
# Example usage: bash ultralytics/yolo/data/scripts/download_weights.sh
|
||||
# Example usage: bash ultralytics/data/scripts/download_weights.sh
|
||||
# parent
|
||||
# └── weights
|
||||
# ├── yolov8n.pt ← downloads here
|
||||
@ -9,9 +9,9 @@
|
||||
# └── ...
|
||||
|
||||
python - <<EOF
|
||||
from ultralytics.yolo.utils.downloads import attempt_download_asset
|
||||
from ultralytics.utils.downloads import attempt_download_asset
|
||||
|
||||
assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '-cls', '-seg')]
|
||||
assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '-cls', '-seg', '-pose')]
|
||||
for x in assets:
|
||||
attempt_download_asset(f'weights/{x}')
|
||||
|
@ -18,11 +18,11 @@ from PIL import ExifTags, Image, ImageOps
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.nn.autobackend import check_class_names
|
||||
from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, clean_url, colorstr, emojis,
|
||||
yaml_load)
|
||||
from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
|
||||
from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file
|
||||
from ultralytics.yolo.utils.ops import segments2boxes
|
||||
from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, clean_url, colorstr, emojis,
|
||||
yaml_load)
|
||||
from ultralytics.utils.checks import check_file, check_font, is_ascii
|
||||
from ultralytics.utils.downloads import download, safe_download, unzip_file
|
||||
from ultralytics.utils.ops import segments2boxes
|
||||
|
||||
HELP_URL = 'See https://docs.ultralytics.com/yolov5/tutorials/train_custom_data'
|
||||
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes
|
||||
@ -296,7 +296,7 @@ def check_cls_dataset(dataset: str, split=''):
|
||||
LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
|
||||
t = time.time()
|
||||
if str(dataset) == 'imagenet':
|
||||
subprocess.run(f"bash {ROOT / 'yolo/data/scripts/get_imagenet.sh'}", shell=True, check=True)
|
||||
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
|
||||
else:
|
||||
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
|
||||
download(url, dir=data_dir.parent)
|
||||
@ -326,7 +326,7 @@ class HUBDatasetStats():
|
||||
autodownload (bool): Attempt to download dataset if not found locally. Default is False.
|
||||
|
||||
Usage
|
||||
from ultralytics.yolo.data.utils import HUBDatasetStats
|
||||
from ultralytics.data.utils import HUBDatasetStats
|
||||
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8.zip', task='detect') # detect dataset
|
||||
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-seg.zip', task='segment') # segment dataset
|
||||
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-pose.zip', task='pose') # pose dataset
|
||||
@ -379,7 +379,7 @@ class HUBDatasetStats():
|
||||
|
||||
def get_json(self, save=False, verbose=False):
|
||||
"""Return dataset JSON for Ultralytics HUB."""
|
||||
from ultralytics.yolo.data import YOLODataset # ClassificationDataset
|
||||
from ultralytics.data import YOLODataset # ClassificationDataset
|
||||
|
||||
def _round(labels):
|
||||
"""Update labels to integer class and 4 decimal place floats."""
|
||||
@ -430,7 +430,7 @@ class HUBDatasetStats():
|
||||
|
||||
def process_images(self):
|
||||
"""Compress images for Ultralytics HUB."""
|
||||
from ultralytics.yolo.data import YOLODataset # ClassificationDataset
|
||||
from ultralytics.data import YOLODataset # ClassificationDataset
|
||||
|
||||
for split in 'train', 'val', 'test':
|
||||
if self.data.get(split) is None:
|
||||
@ -457,7 +457,7 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
|
||||
|
||||
Usage:
|
||||
from pathlib import Path
|
||||
from ultralytics.yolo.data.utils import compress_one_image
|
||||
from ultralytics.data.utils import compress_one_image
|
||||
for f in Path('/Users/glennjocher/Downloads/dataset').rglob('*.jpg'):
|
||||
compress_one_image(f)
|
||||
"""
|
||||
@ -485,7 +485,7 @@ def delete_dsstore(path):
|
||||
path (str, optional): The directory path where the ".DS_store" files should be deleted.
|
||||
|
||||
Usage:
|
||||
from ultralytics.yolo.data.utils import delete_dsstore
|
||||
from ultralytics.data.utils import delete_dsstore
|
||||
delete_dsstore('/Users/glennjocher/Downloads/dataset')
|
||||
|
||||
Note:
|
||||
@ -508,7 +508,7 @@ def zip_directory(dir, use_zipfile_library=True):
|
||||
use_zipfile_library (bool): Whether to use zipfile library or shutil for zipping.
|
||||
|
||||
Usage:
|
||||
from ultralytics.yolo.data.utils import zip_directory
|
||||
from ultralytics.data.utils import zip_directory
|
||||
zip_directory('/Users/glennjocher/Downloads/playground')
|
||||
|
||||
zip -r coco8-pose.zip coco8-pose
|
@ -60,17 +60,17 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.nn.autobackend import check_class_names
|
||||
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
|
||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, __version__, callbacks, colorstr,
|
||||
get_default_args, yaml_save)
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version
|
||||
from ultralytics.yolo.utils.downloads import attempt_download_asset, get_github_assets
|
||||
from ultralytics.yolo.utils.files import file_size
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
|
||||
from ultralytics.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, __version__, callbacks, colorstr,
|
||||
get_default_args, yaml_save)
|
||||
from ultralytics.utils.checks import check_imgsz, check_requirements, check_version
|
||||
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets
|
||||
from ultralytics.utils.files import file_size
|
||||
from ultralytics.utils.ops import Profile
|
||||
from ultralytics.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
|
||||
|
||||
|
||||
def export_formats():
|
@ -4,29 +4,28 @@ import sys
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from ultralytics import yolo # noqa
|
||||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.engine.exporter import Exporter
|
||||
from ultralytics.models import yolo # noqa
|
||||
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel,
|
||||
attempt_load_one_weight, guess_model_task, nn, yaml_model_load)
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
|
||||
is_git_dir, yaml_load)
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
|
||||
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
|
||||
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
|
||||
is_git_dir, yaml_load)
|
||||
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
|
||||
from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
|
||||
from ultralytics.utils.torch_utils import smart_inference_mode
|
||||
|
||||
# Map head to model, trainer, validator, and predictor classes
|
||||
TASK_MAP = {
|
||||
'classify': [
|
||||
ClassificationModel, yolo.v8.classify.ClassificationTrainer, yolo.v8.classify.ClassificationValidator,
|
||||
yolo.v8.classify.ClassificationPredictor],
|
||||
'detect': [
|
||||
DetectionModel, yolo.v8.detect.DetectionTrainer, yolo.v8.detect.DetectionValidator,
|
||||
yolo.v8.detect.DetectionPredictor],
|
||||
ClassificationModel, yolo.classify.ClassificationTrainer, yolo.classify.ClassificationValidator,
|
||||
yolo.classify.ClassificationPredictor],
|
||||
'detect':
|
||||
[DetectionModel, yolo.detect.DetectionTrainer, yolo.detect.DetectionValidator, yolo.detect.DetectionPredictor],
|
||||
'segment': [
|
||||
SegmentationModel, yolo.v8.segment.SegmentationTrainer, yolo.v8.segment.SegmentationValidator,
|
||||
yolo.v8.segment.SegmentationPredictor],
|
||||
'pose': [PoseModel, yolo.v8.pose.PoseTrainer, yolo.v8.pose.PoseValidator, yolo.v8.pose.PosePredictor]}
|
||||
SegmentationModel, yolo.segment.SegmentationTrainer, yolo.segment.SegmentationValidator,
|
||||
yolo.segment.SegmentationPredictor],
|
||||
'pose': [PoseModel, yolo.pose.PoseTrainer, yolo.pose.PoseValidator, yolo.pose.PosePredictor]}
|
||||
|
||||
|
||||
class YOLO:
|
||||
@ -63,11 +62,11 @@ class YOLO:
|
||||
Logs the model info.
|
||||
fuse() -> None:
|
||||
Fuses the model for faster inference.
|
||||
predict(source=None, stream=False, **kwargs) -> List[ultralytics.yolo.engine.results.Results]:
|
||||
predict(source=None, stream=False, **kwargs) -> List[ultralytics.engine.results.Results]:
|
||||
Performs prediction using the YOLO model.
|
||||
|
||||
Returns:
|
||||
list(ultralytics.yolo.engine.results.Results): The prediction results.
|
||||
list(ultralytics.engine.results.Results): The prediction results.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
|
||||
@ -230,7 +229,7 @@ class YOLO:
|
||||
Check the 'configuration' section in the documentation for all available options.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
||||
(List[ultralytics.engine.results.Results]): The prediction results.
|
||||
"""
|
||||
if source is None:
|
||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
||||
@ -265,11 +264,11 @@ class YOLO:
|
||||
**kwargs (optional): Additional keyword arguments for the tracking process.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.yolo.engine.results.Results]): The tracking results.
|
||||
(List[ultralytics.engine.results.Results]): The tracking results.
|
||||
|
||||
"""
|
||||
if not hasattr(self.predictor, 'trackers'):
|
||||
from ultralytics.tracker import register_tracker
|
||||
from ultralytics.trackers import register_tracker
|
||||
register_tracker(self, persist)
|
||||
# ByteTrack-based method needs low confidence predictions as input
|
||||
conf = kwargs.get('conf') or 0.1
|
||||
@ -315,7 +314,7 @@ class YOLO:
|
||||
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
from ultralytics.yolo.utils.benchmarks import benchmark
|
||||
from ultralytics.utils.benchmarks import benchmark
|
||||
overrides = self.model.args.copy()
|
||||
overrides.update(kwargs)
|
||||
overrides['mode'] = 'benchmark'
|
||||
@ -389,7 +388,7 @@ class YOLO:
|
||||
|
||||
def tune(self, *args, **kwargs):
|
||||
"""
|
||||
Runs hyperparameter tuning using Ray Tune. See ultralytics.yolo.utils.tuner.run_ray_tune for Args.
|
||||
Runs hyperparameter tuning using Ray Tune. See ultralytics.utils.tuner.run_ray_tune for Args.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the results of the hyperparameter search.
|
||||
@ -398,7 +397,7 @@ class YOLO:
|
||||
ModuleNotFoundError: If Ray Tune is not installed.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
from ultralytics.yolo.utils.tuner import run_ray_tune
|
||||
from ultralytics.utils.tuner import run_ray_tune
|
||||
return run_ray_tune(self, *args, **kwargs)
|
||||
|
||||
@property
|
@ -34,14 +34,14 @@ import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.data import load_inference_source
|
||||
from ultralytics.data.augment import LetterBox, classify_transforms
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data import load_inference_source
|
||||
from ultralytics.yolo.data.augment import LetterBox, classify_transforms
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, MACOS, SETTINGS, WINDOWS, callbacks, colorstr, ops
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, SETTINGS, WINDOWS, callbacks, colorstr, ops
|
||||
from ultralytics.utils.checks import check_imgsz, check_imshow
|
||||
from ultralytics.utils.files import increment_path
|
||||
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
|
||||
|
||||
STREAM_WARNING = """
|
||||
WARNING ⚠️ stream/video/webcam/dir predict source will accumulate results in RAM unless `stream=True` is passed,
|
@ -12,9 +12,9 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.data.augment import LetterBox
|
||||
from ultralytics.yolo.utils import LOGGER, SimpleClass, deprecation_warn, ops
|
||||
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
||||
from ultralytics.data.augment import LetterBox
|
||||
from ultralytics.utils import LOGGER, SimpleClass, deprecation_warn, ops
|
||||
from ultralytics.utils.plotting import Annotator, colors, save_one_box
|
||||
|
||||
|
||||
class BaseTensor(SimpleClass):
|
@ -21,17 +21,17 @@ from torch.cuda import amp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks,
|
||||
clean_url, colorstr, emojis, yaml_save)
|
||||
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
||||
from ultralytics.yolo.utils.checks import check_amp, check_file, check_imgsz, print_args
|
||||
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
||||
from ultralytics.yolo.utils.files import get_latest_run, increment_path
|
||||
from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle,
|
||||
select_device, strip_optimizer)
|
||||
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, clean_url,
|
||||
colorstr, emojis, yaml_save)
|
||||
from ultralytics.utils.autobatch import check_train_batch_size
|
||||
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
|
||||
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
||||
from ultralytics.utils.files import get_latest_run, increment_path
|
||||
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
|
||||
strip_optimizer)
|
||||
|
||||
|
||||
class BaseTrainer:
|
@ -25,14 +25,14 @@ from pathlib import Path
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
|
||||
from ultralytics.yolo.utils.checks import check_imgsz
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device, smart_inference_mode
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
|
||||
from ultralytics.utils.checks import check_imgsz
|
||||
from ultralytics.utils.files import increment_path
|
||||
from ultralytics.utils.ops import Profile
|
||||
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
|
||||
|
||||
|
||||
class BaseValidator:
|
@ -2,10 +2,10 @@
|
||||
|
||||
import requests
|
||||
|
||||
from ultralytics.data.utils import HUBDatasetStats
|
||||
from ultralytics.hub.auth import Auth
|
||||
from ultralytics.hub.utils import PREFIX
|
||||
from ultralytics.yolo.data.utils import HUBDatasetStats
|
||||
from ultralytics.yolo.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
|
||||
from ultralytics.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
|
||||
|
||||
|
||||
def login(api_key=''):
|
||||
@ -65,7 +65,7 @@ def reset_model(model_id=''):
|
||||
|
||||
def export_fmts_hub():
|
||||
"""Returns a list of HUB-supported export formats."""
|
||||
from ultralytics.yolo.engine.exporter import export_formats
|
||||
from ultralytics.engine.exporter import export_formats
|
||||
return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']
|
||||
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
import requests
|
||||
|
||||
from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, request_with_credentials
|
||||
from ultralytics.yolo.utils import LOGGER, SETTINGS, emojis, is_colab, set_settings
|
||||
from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab, set_settings
|
||||
|
||||
API_KEY_URL = 'https://hub.ultralytics.com/settings?tab=api+keys'
|
||||
|
||||
|
@ -7,8 +7,8 @@ from time import sleep
|
||||
import requests
|
||||
|
||||
from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, smart_request
|
||||
from ultralytics.yolo.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
|
||||
from ultralytics.yolo.utils.errors import HUBModelError
|
||||
from ultralytics.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
|
||||
from ultralytics.utils.errors import HUBModelError
|
||||
|
||||
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
|
||||
|
||||
|
@ -11,9 +11,8 @@ from pathlib import Path
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM_BAR_FORMAT,
|
||||
TryExcept, __version__, colorstr, get_git_origin_url, is_colab, is_git_dir,
|
||||
is_pip_package)
|
||||
from ultralytics.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM_BAR_FORMAT, TryExcept,
|
||||
__version__, colorstr, get_git_origin_url, is_colab, is_git_dir, is_pip_package)
|
||||
|
||||
PREFIX = colorstr('Ultralytics HUB: ')
|
||||
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
||||
|
@ -1,5 +1,3 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .rtdetr import RTDETR
|
||||
from .sam import SAM
|
||||
|
@ -9,13 +9,13 @@ Usage - Predict:
|
||||
results = model.predict('ultralytics/assets/bus.jpg')
|
||||
"""
|
||||
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
|
||||
from ultralytics.yolo.utils.checks import check_imgsz
|
||||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.engine.exporter import Exporter
|
||||
from ultralytics.engine.model import YOLO
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
|
||||
from ultralytics.utils.checks import check_imgsz
|
||||
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
|
||||
|
||||
from ...yolo.utils.torch_utils import model_info, smart_inference_mode
|
||||
from .predict import FastSAMPredictor
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ class FastSAM(YOLO):
|
||||
Check the 'configuration' section in the documentation for all available options.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
||||
(List[ultralytics.engine.results.Results]): The prediction results.
|
||||
"""
|
||||
if source is None:
|
||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
@ -2,10 +2,10 @@
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.fastsam.utils import bbox_iou
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ops
|
||||
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.models.fastsam.utils import bbox_iou
|
||||
from ultralytics.models.yolo.detect.predict import DetectionPredictor
|
||||
from ultralytics.utils import DEFAULT_CFG, ops
|
||||
|
||||
|
||||
class FastSAMPredictor(DetectionPredictor):
|
@ -22,7 +22,7 @@ class FastSAMPrompt:
|
||||
try:
|
||||
import clip # for linear_assignment
|
||||
except ImportError:
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
|
||||
import clip
|
||||
self.clip = clip
|
@ -7,11 +7,11 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, NUM_THREADS, ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.yolo.utils.metrics import SegmentMetrics, box_iou, mask_iou
|
||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
|
||||
from ultralytics.yolo.v8.detect import DetectionValidator
|
||||
from ultralytics.models.yolo.detect import DetectionValidator
|
||||
from ultralytics.utils import LOGGER, NUM_THREADS, ops
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
|
||||
from ultralytics.utils.plotting import output_to_target, plot_images
|
||||
|
||||
|
||||
class FastSAMValidator(DetectionValidator):
|
@ -13,12 +13,12 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
|
||||
from ultralytics.yolo.utils.checks import check_imgsz
|
||||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.engine.exporter import Exporter
|
||||
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
|
||||
from ultralytics.utils.checks import check_imgsz
|
||||
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
|
||||
|
||||
from ...yolo.utils.torch_utils import model_info, smart_inference_mode
|
||||
from .predict import NASPredictor
|
||||
from .val import NASValidator
|
||||
|
||||
@ -65,7 +65,7 @@ class NAS:
|
||||
Check the 'configuration' section in the documentation for all available options.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
||||
(List[ultralytics.engine.results.Results]): The prediction results.
|
||||
"""
|
||||
if source is None:
|
||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
@ -2,10 +2,10 @@
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import ops
|
||||
from ultralytics.yolo.utils.ops import xyxy2xywh
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import ops
|
||||
from ultralytics.utils.ops import xyxy2xywh
|
||||
|
||||
|
||||
class NASPredictor(BasePredictor):
|
@ -2,9 +2,9 @@
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.utils import ops
|
||||
from ultralytics.yolo.utils.ops import xyxy2xywh
|
||||
from ultralytics.yolo.v8.detect import DetectionValidator
|
||||
from ultralytics.models.yolo.detect import DetectionValidator
|
||||
from ultralytics.utils import ops
|
||||
from ultralytics.utils.ops import xyxy2xywh
|
||||
|
||||
__all__ = ['NASValidator']
|
||||
|
@ -7,12 +7,12 @@ from pathlib import Path
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.engine.exporter import Exporter
|
||||
from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir
|
||||
from ultralytics.yolo.utils.checks import check_imgsz
|
||||
from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode
|
||||
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir
|
||||
from ultralytics.utils.checks import check_imgsz
|
||||
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
|
||||
|
||||
from .predict import RTDETRPredictor
|
||||
from .train import RTDETRTrainer
|
||||
@ -72,7 +72,7 @@ class RTDETR:
|
||||
Check the 'configuration' section in the documentation for all available options.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
||||
(List[ultralytics.engine.results.Results]): The prediction results.
|
||||
"""
|
||||
if source is None:
|
||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
@ -2,10 +2,10 @@
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.data.augment import LetterBox
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import ops
|
||||
from ultralytics.data.augment import LetterBox
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import ops
|
||||
|
||||
|
||||
class RTDETRPredictor(BasePredictor):
|
@ -4,9 +4,9 @@ from copy import copy
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.models.yolo.detect import DetectionTrainer
|
||||
from ultralytics.nn.tasks import RTDETRDetectionModel
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr
|
||||
from ultralytics.yolo.v8.detect import DetectionTrainer
|
||||
from ultralytics.utils import DEFAULT_CFG, RANK, colorstr
|
||||
|
||||
from .val import RTDETRDataset, RTDETRValidator
|
||||
|
@ -6,10 +6,10 @@ import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.data import YOLODataset
|
||||
from ultralytics.yolo.data.augment import Compose, Format, v8_transforms
|
||||
from ultralytics.yolo.utils import colorstr, ops
|
||||
from ultralytics.yolo.v8.detect import DetectionValidator
|
||||
from ultralytics.data import YOLODataset
|
||||
from ultralytics.data.augment import Compose, Format, v8_transforms
|
||||
from ultralytics.models.yolo.detect import DetectionValidator
|
||||
from ultralytics.utils import colorstr, ops
|
||||
|
||||
__all__ = 'RTDETRValidator', # tuple or list
|
||||
|
@ -10,7 +10,8 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from ...yolo.utils.downloads import attempt_download_asset
|
||||
from ultralytics.utils.downloads import attempt_download_asset
|
||||
|
||||
from .modules.decoders import MaskDecoder
|
||||
from .modules.encoders import ImageEncoderViT, PromptEncoder
|
||||
from .modules.sam import Sam
|
@ -3,8 +3,8 @@
|
||||
SAM model interface
|
||||
"""
|
||||
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.utils.torch_utils import model_info
|
||||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.utils.torch_utils import model_info
|
||||
|
||||
from .build import build_sam
|
||||
from .predict import Predictor
|
@ -17,7 +17,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from ultralytics.yolo.utils.instance import to_2tuple
|
||||
from ultralytics.utils.instance import to_2tuple
|
||||
|
||||
|
||||
class Conv2d_BN(torch.nn.Sequential):
|
||||
@ -50,7 +50,7 @@ class Conv2d_BN(torch.nn.Sequential):
|
||||
|
||||
|
||||
# NOTE: This module and timm package is needed only for training.
|
||||
# from ultralytics.yolo.utils.checks import check_requirements
|
||||
# from ultralytics.utils.checks import check_requirements
|
||||
# check_requirements('timm')
|
||||
# from timm.models.layers import DropPath as TimmDropPath
|
||||
# from timm.models.layers import trunc_normal_
|
@ -5,11 +5,11 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from ultralytics.yolo.data.augment import LetterBox
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ops
|
||||
from ultralytics.yolo.utils.torch_utils import select_device
|
||||
from ultralytics.data.augment import LetterBox
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import DEFAULT_CFG, ops
|
||||
from ultralytics.utils.torch_utils import select_device
|
||||
|
||||
from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
|
||||
generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
|
@ -4,9 +4,10 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.vit.utils.ops import HungarianMatcher
|
||||
from ultralytics.yolo.utils.loss import FocalLoss, VarifocalLoss
|
||||
from ultralytics.yolo.utils.metrics import bbox_iou
|
||||
from ultralytics.utils.loss import FocalLoss, VarifocalLoss
|
||||
from ultralytics.utils.metrics import bbox_iou
|
||||
|
||||
from .ops import HungarianMatcher
|
||||
|
||||
|
||||
class DETRLoss(nn.Module):
|
@ -5,8 +5,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from ultralytics.yolo.utils.metrics import bbox_iou
|
||||
from ultralytics.yolo.utils.ops import xywh2xyxy, xyxy2xywh
|
||||
from ultralytics.utils.metrics import bbox_iou
|
||||
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
|
||||
|
||||
|
||||
class HungarianMatcher(nn.Module):
|
5
ultralytics/models/yolo/__init__.py
Normal file
5
ultralytics/models/yolo/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from ultralytics.models.yolo import classify, detect, pose, segment
|
||||
|
||||
__all__ = 'classify', 'segment', 'detect', 'pose'
|
7
ultralytics/models/yolo/classify/__init__.py
Normal file
7
ultralytics/models/yolo/classify/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from ultralytics.models.yolo.classify.predict import ClassificationPredictor, predict
|
||||
from ultralytics.models.yolo.classify.train import ClassificationTrainer, train
|
||||
from ultralytics.models.yolo.classify.val import ClassificationValidator, val
|
||||
|
||||
__all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val'
|
@ -2,9 +2,9 @@
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import DEFAULT_CFG, ROOT
|
||||
|
||||
|
||||
class ClassificationPredictor(BasePredictor):
|
@ -3,13 +3,13 @@
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from ultralytics.data import ClassificationDataset, build_dataloader
|
||||
from ultralytics.engine.trainer import BaseTrainer
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
|
||||
from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.data import ClassificationDataset, build_dataloader
|
||||
from ultralytics.yolo.engine.trainer import BaseTrainer
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
|
||||
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
||||
from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
|
||||
from ultralytics.utils.plotting import plot_images, plot_results
|
||||
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
|
||||
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
@ -98,7 +98,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
def get_validator(self):
|
||||
"""Returns an instance of ClassificationValidator for validation."""
|
||||
self.loss_names = ['loss']
|
||||
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
|
||||
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir)
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||
"""
|
@ -2,11 +2,11 @@
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.data import ClassificationDataset, build_dataloader
|
||||
from ultralytics.yolo.engine.validator import BaseValidator
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER
|
||||
from ultralytics.yolo.utils.metrics import ClassifyMetrics, ConfusionMatrix
|
||||
from ultralytics.yolo.utils.plotting import plot_images
|
||||
from ultralytics.data import ClassificationDataset, build_dataloader
|
||||
from ultralytics.engine.validator import BaseValidator
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER
|
||||
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
|
||||
from ultralytics.utils.plotting import plot_images
|
||||
|
||||
|
||||
class ClassificationValidator(BaseValidator):
|
@ -2,9 +2,9 @@
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.results import Results
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import DEFAULT_CFG, ROOT, ops
|
||||
|
||||
|
||||
class DetectionPredictor(BasePredictor):
|
@ -3,13 +3,13 @@ from copy import copy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ultralytics.data import build_dataloader, build_yolo_dataset
|
||||
from ultralytics.engine.trainer import BaseTrainer
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import DetectionModel
|
||||
from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.data import build_dataloader, build_yolo_dataset
|
||||
from ultralytics.yolo.engine.trainer import BaseTrainer
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK
|
||||
from ultralytics.yolo.utils.plotting import plot_images, plot_labels, plot_results
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel, torch_distributed_zero_first
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
|
||||
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
|
||||
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
|
||||
|
||||
|
||||
# BaseTrainer python usage
|
||||
@ -64,7 +64,7 @@ class DetectionTrainer(BaseTrainer):
|
||||
def get_validator(self):
|
||||
"""Returns a DetectionValidator for YOLO model validation."""
|
||||
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
|
||||
return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
|
||||
return yolo.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||
"""
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user