ultralytics 8.0.136 refactor and simplify package (#3748)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing
2023-07-16 23:47:45 +08:00
committed by GitHub
parent 8ebe94d1e9
commit 620f3eb218
383 changed files with 4213 additions and 4646 deletions

View File

@ -1,14 +1,13 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.135'
__version__ = '8.0.136'
from ultralytics.engine.model import YOLO
from ultralytics.hub import start
from ultralytics.vit.rtdetr import RTDETR
from ultralytics.vit.sam import SAM
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.fastsam import FastSAM
from ultralytics.yolo.nas import NAS
from ultralytics.yolo.utils.checks import check_yolo as checks
from ultralytics.yolo.utils.downloads import download
from ultralytics.models import RTDETR, SAM
from ultralytics.models.fastsam import FastSAM
from ultralytics.models.nas import NAS
from ultralytics.utils.checks import check_yolo as checks
from ultralytics.utils.downloads import download
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'start' # allow simpler import

421
ultralytics/cfg/__init__.py Normal file
View File

@ -0,0 +1,421 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
import re
import shutil
import sys
from difflib import get_close_matches
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Union
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR,
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, get_settings,
yaml_load, yaml_print)
# Define valid tasks and modes
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
TASKS = 'detect', 'segment', 'classify', 'pose'
TASK2DATA = {'detect': 'coco8.yaml', 'segment': 'coco8-seg.yaml', 'classify': 'imagenet100', 'pose': 'coco8-pose.yaml'}
TASK2MODEL = {
'detect': 'yolov8n.pt',
'segment': 'yolov8n-seg.pt',
'classify': 'yolov8n-cls.pt',
'pose': 'yolov8n-pose.pt'}
TASK2METRIC = {
'detect': 'metrics/mAP50-95(B)',
'segment': 'metrics/mAP50-95(M)',
'classify': 'metrics/accuracy_top1',
'pose': 'metrics/mAP50-95(P)'}
CLI_HELP_MSG = \
f"""
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
yolo TASK MODE ARGS
Where TASK (optional) is one of {TASKS}
MODE (required) is one of {MODES}
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'
1. Train a detection model for 10 epochs with an initial learning_rate of 0.01
yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
2. Predict a YouTube video using a pretrained segmentation model at image size 320:
yolo predict model=yolov8n-seg.pt source='https://youtu.be/Zgi9g1ksQHc' imgsz=320
3. Val a pretrained detection model at batch-size 1 and image size 640:
yolo val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
5. Run special commands:
yolo help
yolo checks
yolo version
yolo settings
yolo copy-cfg
yolo cfg
Docs: https://docs.ultralytics.com
Community: https://community.ultralytics.com
GitHub: https://github.com/ultralytics/ultralytics
"""
# Define keys for arg type checks
CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'
CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction') # fraction floats 0.0 - 1.0
CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
'line_width', 'workspace', 'nbs', 'save_period')
CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val',
'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop',
'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'profile')
def cfg2dict(cfg):
"""
Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
Args:
cfg (str | Path | SimpleNamespace): Configuration object to be converted to a dictionary.
Returns:
cfg (dict): Configuration object in dictionary format.
"""
if isinstance(cfg, (str, Path)):
cfg = yaml_load(cfg) # load dict
elif isinstance(cfg, SimpleNamespace):
cfg = vars(cfg) # convert to dict
return cfg
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
"""
Load and merge configuration data from a file or dictionary.
Args:
cfg (str | Path | Dict | SimpleNamespace): Configuration data.
overrides (str | Dict | optional): Overrides in the form of a file name or a dictionary. Default is None.
Returns:
(SimpleNamespace): Training arguments namespace.
"""
cfg = cfg2dict(cfg)
# Merge overrides
if overrides:
overrides = cfg2dict(overrides)
check_cfg_mismatch(cfg, overrides)
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
# Special handling for numeric project/name
for k in 'project', 'name':
if k in cfg and isinstance(cfg[k], (int, float)):
cfg[k] = str(cfg[k])
if cfg.get('name') == 'model': # assign model to 'name' arg
cfg['name'] = cfg.get('model', '').split('.')[0]
LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
# Type and Value checks
for k, v in cfg.items():
if v is not None: # None values may be from optional args
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
elif k in CFG_FRACTION_KEYS:
if not isinstance(v, (int, float)):
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
if not (0.0 <= v <= 1.0):
raise ValueError(f"'{k}={v}' is an invalid value. "
f"Valid '{k}' values are between 0.0 and 1.0.")
elif k in CFG_INT_KEYS and not isinstance(v, int):
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"'{k}' must be an int (i.e. '{k}=8')")
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
# Return instance
return IterableSimpleNamespace(**cfg)
def _handle_deprecation(custom):
"""
Hardcoded function to handle deprecated config keys
"""
for key in custom.copy().keys():
if key == 'hide_labels':
deprecation_warn(key, 'show_labels')
custom['show_labels'] = custom.pop('hide_labels') == 'False'
if key == 'hide_conf':
deprecation_warn(key, 'show_conf')
custom['show_conf'] = custom.pop('hide_conf') == 'False'
if key == 'line_thickness':
deprecation_warn(key, 'line_width')
custom['line_width'] = custom.pop('line_thickness')
return custom
def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
"""
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
Args:
custom (dict): a dictionary of custom configuration options
base (dict): a dictionary of base configuration options
"""
custom = _handle_deprecation(custom)
base, custom = (set(x.keys()) for x in (base, custom))
mismatched = [x for x in custom if x not in base]
if mismatched:
string = ''
for x in mismatched:
matches = get_close_matches(x, base) # key list
matches = [f'{k}={DEFAULT_CFG_DICT[k]}' if DEFAULT_CFG_DICT.get(k) is not None else k for k in matches]
match_str = f'Similar arguments are i.e. {matches}.' if matches else ''
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
raise SyntaxError(string + CLI_HELP_MSG) from e
def merge_equals_args(args: List[str]) -> List[str]:
"""
Merges arguments around isolated '=' args in a list of strings.
The function considers cases where the first argument ends with '=' or the second starts with '=',
as well as when the middle one is an equals sign.
Args:
args (List[str]): A list of strings where each element is an argument.
Returns:
List[str]: A list of strings where the arguments around isolated '=' are merged.
"""
new_args = []
for i, arg in enumerate(args):
if arg == '=' and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
new_args[-1] += f'={args[i + 1]}'
del args[i + 1]
elif arg.endswith('=') and i < len(args) - 1 and '=' not in args[i + 1]: # merge ['arg=', 'val']
new_args.append(f'{arg}{args[i + 1]}')
del args[i + 1]
elif arg.startswith('=') and i > 0: # merge ['arg', '=val']
new_args[-1] += arg
else:
new_args.append(arg)
return new_args
def handle_yolo_hub(args: List[str]) -> None:
"""
Handle Ultralytics HUB command-line interface (CLI) commands.
This function processes Ultralytics HUB CLI commands such as login and logout.
It should be called when executing a script with arguments related to HUB authentication.
Args:
args (List[str]): A list of command line arguments
Example:
python my_script.py hub login your_api_key
"""
from ultralytics import hub
if args[0] == 'login':
key = args[1] if len(args) > 1 else ''
# Log in to Ultralytics HUB using the provided API key
hub.login(key)
elif args[0] == 'logout':
# Log out from Ultralytics HUB
hub.logout()
def handle_yolo_settings(args: List[str]) -> None:
"""
Handle YOLO settings command-line interface (CLI) commands.
This function processes YOLO settings CLI commands such as reset.
It should be called when executing a script with arguments related to YOLO settings management.
Args:
args (List[str]): A list of command line arguments for YOLO settings management.
Example:
python my_script.py yolo settings reset
"""
path = USER_CONFIG_DIR / 'settings.yaml' # get SETTINGS YAML file path
if any(args) and args[0] == 'reset':
path.unlink() # delete the settings file
get_settings() # create new settings
LOGGER.info('Settings reset successfully') # inform the user that settings have been reset
yaml_print(path) # print the current settings
def entrypoint(debug=''):
"""
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
to the package.
This function allows for:
- passing mandatory YOLO args as a list of strings
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
- specifying the mode, either 'train', 'val', 'test', or 'predict'
- running special modes like 'checks'
- passing overrides to the package's configuration
It uses the package's default cfg and initializes it using the passed overrides.
Then it calls the CLI function with the composed cfg
"""
args = (debug.split(' ') if debug else sys.argv)[1:]
if not args: # no arguments passed
LOGGER.info(CLI_HELP_MSG)
return
special = {
'help': lambda: LOGGER.info(CLI_HELP_MSG),
'checks': checks.check_yolo,
'version': lambda: LOGGER.info(__version__),
'settings': lambda: handle_yolo_settings(args[1:]),
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
'hub': lambda: handle_yolo_hub(args[1:]),
'login': lambda: handle_yolo_hub(args),
'copy-cfg': copy_default_cfg}
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
# Define common mis-uses of special commands, i.e. -h, -help, --help
special.update({k[0]: v for k, v in special.items()}) # singular
special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith('s')}) # singular
special = {**special, **{f'-{k}': v for k, v in special.items()}, **{f'--{k}': v for k, v in special.items()}}
overrides = {} # basic overrides, i.e. imgsz=320
for a in merge_equals_args(args): # merge spaces around '=' sign
if a.startswith('--'):
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
a = a[2:]
if a.endswith(','):
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
a = a[:-1]
if '=' in a:
try:
re.sub(r' *= *', '=', a) # remove spaces around equals sign
k, v = a.split('=', 1) # split on first '=' sign
assert v, f"missing '{k}' value"
if k == 'cfg': # custom.yaml passed
LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
else:
if v.lower() == 'none':
v = None
elif v.lower() == 'true':
v = True
elif v.lower() == 'false':
v = False
else:
with contextlib.suppress(Exception):
v = eval(v)
overrides[k] = v
except (NameError, SyntaxError, ValueError, AssertionError) as e:
check_cfg_mismatch(full_args_dict, {a: ''}, e)
elif a in TASKS:
overrides['task'] = a
elif a in MODES:
overrides['mode'] = a
elif a.lower() in special:
special[a.lower()]()
return
elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
elif a in DEFAULT_CFG_DICT:
raise SyntaxError(f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}")
else:
check_cfg_mismatch(full_args_dict, {a: ''})
# Check keys
check_cfg_mismatch(full_args_dict, overrides)
# Mode
mode = overrides.get('mode', None)
if mode is None:
mode = DEFAULT_CFG.mode or 'predict'
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
elif mode not in MODES:
if mode not in ('checks', checks):
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
checks.check_yolo()
return
# Task
task = overrides.pop('task', None)
if task:
if task not in TASKS:
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
if 'model' not in overrides:
overrides['model'] = TASK2MODEL[task]
# Model
model = overrides.pop('model', DEFAULT_CFG.model)
if model is None:
model = 'yolov8n.pt'
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
overrides['model'] = model
if 'rtdetr' in model.lower(): # guess architecture
from ultralytics import RTDETR
model = RTDETR(model) # no task argument
elif 'sam' in model.lower():
from ultralytics import SAM
model = SAM(model)
else:
from ultralytics import YOLO
model = YOLO(model, task=task)
if isinstance(overrides.get('pretrained'), str):
model.load(overrides['pretrained'])
# Task Update
if task != model.task:
if task:
LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.")
task = model.task
# Mode
if mode in ('predict', 'track') and 'source' not in overrides:
overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \
else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
elif mode in ('train', 'val'):
if 'data' not in overrides:
overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
elif mode == 'export':
if 'format' not in overrides:
overrides['format'] = DEFAULT_CFG.format or 'torchscript'
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
# Run command in python
# getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml
getattr(model, mode)(**overrides) # default args from model
# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_cfg():
"""Copy and create a new default configuration file with '_copy' appended to its name."""
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
shutil.copy2(DEFAULT_CFG_PATH, new_file)
LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'
f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8")
if __name__ == '__main__':
# Example Usage: entrypoint(debug='yolo predict model=yolov8n.pt')
entrypoint(debug='')

View File

@ -29,7 +29,7 @@ names:
download: |
import json
from tqdm import tqdm
from ultralytics.yolo.utils.downloads import download
from ultralytics.utils.downloads import download
from pathlib import Path
def argoverse2yolo(set):

View File

@ -32,7 +32,7 @@ names:
# Download script/URL (optional) ---------------------------------------------------------------------------------------
download: |
from ultralytics.yolo.utils.downloads import download
from ultralytics.utils.downloads import download
from pathlib import Path
# Download

View File

@ -386,9 +386,9 @@ names:
download: |
from tqdm import tqdm
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.downloads import download
from ultralytics.yolo.utils.ops import xyxy2xywhn
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.downloads import download
from ultralytics.utils.ops import xyxy2xywhn
import numpy as np
from pathlib import Path

View File

@ -27,8 +27,8 @@ download: |
import pandas as pd
from tqdm import tqdm
from ultralytics.yolo.utils.downloads import download
from ultralytics.yolo.utils.ops import xyxy2xywh
from ultralytics.utils.downloads import download
from ultralytics.utils.ops import xyxy2xywh
# Download
dir = Path(yaml['path']) # dataset root dir

View File

@ -48,7 +48,7 @@ download: |
import xml.etree.ElementTree as ET
from tqdm import tqdm
from ultralytics.yolo.utils.downloads import download
from ultralytics.utils.downloads import download
from pathlib import Path
def convert_label(path, lb_path, year, image_id):

View File

@ -32,7 +32,7 @@ download: |
import os
from pathlib import Path
from ultralytics.yolo.utils.downloads import download
from ultralytics.utils.downloads import download
def visdrone2yolo(dir):
from PIL import Image

View File

@ -23,7 +23,7 @@ names:
# Download script/URL (optional)
download: |
from ultralytics.yolo.utils.downloads import download
from ultralytics.utils.downloads import download
from pathlib import Path
# Download labels

View File

@ -99,7 +99,7 @@ names:
# Download script/URL (optional)
download: |
from ultralytics.yolo.utils.downloads import download
from ultralytics.utils.downloads import download
from pathlib import Path
# Download labels

View File

@ -87,8 +87,8 @@ download: |
from PIL import Image
from tqdm import tqdm
from ultralytics.yolo.data.utils import autosplit
from ultralytics.yolo.utils.ops import xyxy2xywhn
from ultralytics.data.utils import autosplit
from ultralytics.utils.ops import xyxy2xywhn
def convert_labels(fname=Path('xView/xView_train.geojson')):

View File

@ -0,0 +1,8 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .base import BaseDataset
from .build import build_dataloader, build_yolo_dataset, load_inference_source
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
__all__ = ('BaseDataset', 'ClassificationDataset', 'SemanticDataset', 'YOLODataset', 'build_yolo_dataset',
'build_dataloader', 'load_inference_source')

View File

@ -9,11 +9,12 @@ import numpy as np
import torch
import torchvision.transforms as T
from ..utils import LOGGER, colorstr
from ..utils.checks import check_version
from ..utils.instance import Instances
from ..utils.metrics import bbox_ioa
from ..utils.ops import segment2box
from ultralytics.utils import LOGGER, colorstr
from ultralytics.utils.checks import check_version
from ultralytics.utils.instance import Instances
from ultralytics.utils.metrics import bbox_ioa
from ultralytics.utils.ops import segment2box
from .utils import polygons2masks, polygons2masks_overlap
POSE_FLIPLR_INDEX = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]

View File

@ -15,7 +15,8 @@ import psutil
from torch.utils.data import Dataset
from tqdm import tqdm
from ..utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT
from .utils import HELP_URL, IMG_FORMATS

View File

@ -9,12 +9,12 @@ import torch
from PIL import Image
from torch.utils.data import dataloader, distributed
from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots,
LoadStreams, LoadTensor, SourceTypes, autocast_list)
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils.checks import check_file
from ultralytics.data.loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams, LoadTensor,
SourceTypes, autocast_list)
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.utils import RANK, colorstr
from ultralytics.utils.checks import check_file
from ..utils import RANK, colorstr
from .dataset import YOLODataset
from .utils import PIN_MEMORY

View File

@ -6,8 +6,8 @@ import cv2
import numpy as np
from tqdm import tqdm
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.files import make_dirs
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.files import make_dirs
def coco91_to_coco80_class():

View File

@ -10,7 +10,8 @@ import torch
import torchvision
from tqdm import tqdm
from ..utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label

View File

@ -15,9 +15,9 @@ import requests
import torch
from PIL import Image
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
from ultralytics.utils.checks import check_requirements
@dataclass
@ -318,11 +318,10 @@ class LoadTensor:
s = f'WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) ' \
f'divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible.'
if len(im.shape) != 4:
if len(im.shape) == 3:
LOGGER.warning(s)
im = im.unsqueeze(0)
else:
if len(im.shape) != 3:
raise ValueError(s)
LOGGER.warning(s)
im = im.unsqueeze(0)
if im.shape[2] % stride or im.shape[3] % stride:
raise ValueError(s)
if im.max() > 1.0:

View File

@ -1,7 +1,7 @@
#!/bin/bash
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Download latest models from https://github.com/ultralytics/assets/releases
# Example usage: bash ultralytics/yolo/data/scripts/download_weights.sh
# Example usage: bash ultralytics/data/scripts/download_weights.sh
# parent
# └── weights
# ├── yolov8n.pt ← downloads here
@ -9,9 +9,9 @@
# └── ...
python - <<EOF
from ultralytics.yolo.utils.downloads import attempt_download_asset
from ultralytics.utils.downloads import attempt_download_asset
assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '-cls', '-seg')]
assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '-cls', '-seg', '-pose')]
for x in assets:
attempt_download_asset(f'weights/{x}')

View File

@ -18,11 +18,11 @@ from PIL import ExifTags, Image, ImageOps
from tqdm import tqdm
from ultralytics.nn.autobackend import check_class_names
from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, clean_url, colorstr, emojis,
yaml_load)
from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file
from ultralytics.yolo.utils.ops import segments2boxes
from ultralytics.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, clean_url, colorstr, emojis,
yaml_load)
from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file
from ultralytics.utils.ops import segments2boxes
HELP_URL = 'See https://docs.ultralytics.com/yolov5/tutorials/train_custom_data'
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes
@ -296,7 +296,7 @@ def check_cls_dataset(dataset: str, split=''):
LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
t = time.time()
if str(dataset) == 'imagenet':
subprocess.run(f"bash {ROOT / 'yolo/data/scripts/get_imagenet.sh'}", shell=True, check=True)
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
else:
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
download(url, dir=data_dir.parent)
@ -326,7 +326,7 @@ class HUBDatasetStats():
autodownload (bool): Attempt to download dataset if not found locally. Default is False.
Usage
from ultralytics.yolo.data.utils import HUBDatasetStats
from ultralytics.data.utils import HUBDatasetStats
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8.zip', task='detect') # detect dataset
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-seg.zip', task='segment') # segment dataset
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco8-pose.zip', task='pose') # pose dataset
@ -379,7 +379,7 @@ class HUBDatasetStats():
def get_json(self, save=False, verbose=False):
"""Return dataset JSON for Ultralytics HUB."""
from ultralytics.yolo.data import YOLODataset # ClassificationDataset
from ultralytics.data import YOLODataset # ClassificationDataset
def _round(labels):
"""Update labels to integer class and 4 decimal place floats."""
@ -430,7 +430,7 @@ class HUBDatasetStats():
def process_images(self):
"""Compress images for Ultralytics HUB."""
from ultralytics.yolo.data import YOLODataset # ClassificationDataset
from ultralytics.data import YOLODataset # ClassificationDataset
for split in 'train', 'val', 'test':
if self.data.get(split) is None:
@ -457,7 +457,7 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
Usage:
from pathlib import Path
from ultralytics.yolo.data.utils import compress_one_image
from ultralytics.data.utils import compress_one_image
for f in Path('/Users/glennjocher/Downloads/dataset').rglob('*.jpg'):
compress_one_image(f)
"""
@ -485,7 +485,7 @@ def delete_dsstore(path):
path (str, optional): The directory path where the ".DS_store" files should be deleted.
Usage:
from ultralytics.yolo.data.utils import delete_dsstore
from ultralytics.data.utils import delete_dsstore
delete_dsstore('/Users/glennjocher/Downloads/dataset')
Note:
@ -508,7 +508,7 @@ def zip_directory(dir, use_zipfile_library=True):
use_zipfile_library (bool): Whether to use zipfile library or shutil for zipping.
Usage:
from ultralytics.yolo.data.utils import zip_directory
from ultralytics.data.utils import zip_directory
zip_directory('/Users/glennjocher/Downloads/playground')
zip -r coco8-pose.zip coco8-pose

View File

@ -60,17 +60,17 @@ from pathlib import Path
import torch
from ultralytics.cfg import get_cfg
from ultralytics.nn.autobackend import check_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, __version__, callbacks, colorstr,
get_default_args, yaml_save)
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version
from ultralytics.yolo.utils.downloads import attempt_download_asset, get_github_assets
from ultralytics.yolo.utils.files import file_size
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
from ultralytics.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, __version__, callbacks, colorstr,
get_default_args, yaml_save)
from ultralytics.utils.checks import check_imgsz, check_requirements, check_version
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets
from ultralytics.utils.files import file_size
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
def export_formats():

View File

@ -4,29 +4,28 @@ import sys
from pathlib import Path
from typing import Union
from ultralytics import yolo # noqa
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.models import yolo # noqa
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel,
attempt_load_one_weight, guess_model_task, nn, yaml_model_load)
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
is_git_dir, yaml_load)
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
is_git_dir, yaml_load)
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.utils.torch_utils import smart_inference_mode
# Map head to model, trainer, validator, and predictor classes
TASK_MAP = {
'classify': [
ClassificationModel, yolo.v8.classify.ClassificationTrainer, yolo.v8.classify.ClassificationValidator,
yolo.v8.classify.ClassificationPredictor],
'detect': [
DetectionModel, yolo.v8.detect.DetectionTrainer, yolo.v8.detect.DetectionValidator,
yolo.v8.detect.DetectionPredictor],
ClassificationModel, yolo.classify.ClassificationTrainer, yolo.classify.ClassificationValidator,
yolo.classify.ClassificationPredictor],
'detect':
[DetectionModel, yolo.detect.DetectionTrainer, yolo.detect.DetectionValidator, yolo.detect.DetectionPredictor],
'segment': [
SegmentationModel, yolo.v8.segment.SegmentationTrainer, yolo.v8.segment.SegmentationValidator,
yolo.v8.segment.SegmentationPredictor],
'pose': [PoseModel, yolo.v8.pose.PoseTrainer, yolo.v8.pose.PoseValidator, yolo.v8.pose.PosePredictor]}
SegmentationModel, yolo.segment.SegmentationTrainer, yolo.segment.SegmentationValidator,
yolo.segment.SegmentationPredictor],
'pose': [PoseModel, yolo.pose.PoseTrainer, yolo.pose.PoseValidator, yolo.pose.PosePredictor]}
class YOLO:
@ -63,11 +62,11 @@ class YOLO:
Logs the model info.
fuse() -> None:
Fuses the model for faster inference.
predict(source=None, stream=False, **kwargs) -> List[ultralytics.yolo.engine.results.Results]:
predict(source=None, stream=False, **kwargs) -> List[ultralytics.engine.results.Results]:
Performs prediction using the YOLO model.
Returns:
list(ultralytics.yolo.engine.results.Results): The prediction results.
list(ultralytics.engine.results.Results): The prediction results.
"""
def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
@ -230,7 +229,7 @@ class YOLO:
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
@ -265,11 +264,11 @@ class YOLO:
**kwargs (optional): Additional keyword arguments for the tracking process.
Returns:
(List[ultralytics.yolo.engine.results.Results]): The tracking results.
(List[ultralytics.engine.results.Results]): The tracking results.
"""
if not hasattr(self.predictor, 'trackers'):
from ultralytics.tracker import register_tracker
from ultralytics.trackers import register_tracker
register_tracker(self, persist)
# ByteTrack-based method needs low confidence predictions as input
conf = kwargs.get('conf') or 0.1
@ -315,7 +314,7 @@ class YOLO:
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
"""
self._check_is_pytorch_model()
from ultralytics.yolo.utils.benchmarks import benchmark
from ultralytics.utils.benchmarks import benchmark
overrides = self.model.args.copy()
overrides.update(kwargs)
overrides['mode'] = 'benchmark'
@ -389,7 +388,7 @@ class YOLO:
def tune(self, *args, **kwargs):
"""
Runs hyperparameter tuning using Ray Tune. See ultralytics.yolo.utils.tuner.run_ray_tune for Args.
Runs hyperparameter tuning using Ray Tune. See ultralytics.utils.tuner.run_ray_tune for Args.
Returns:
(dict): A dictionary containing the results of the hyperparameter search.
@ -398,7 +397,7 @@ class YOLO:
ModuleNotFoundError: If Ray Tune is not installed.
"""
self._check_is_pytorch_model()
from ultralytics.yolo.utils.tuner import run_ray_tune
from ultralytics.utils.tuner import run_ray_tune
return run_ray_tune(self, *args, **kwargs)
@property

View File

@ -34,14 +34,14 @@ import cv2
import numpy as np
import torch
from ultralytics.cfg import get_cfg
from ultralytics.data import load_inference_source
from ultralytics.data.augment import LetterBox, classify_transforms
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data import load_inference_source
from ultralytics.yolo.data.augment import LetterBox, classify_transforms
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, MACOS, SETTINGS, WINDOWS, callbacks, colorstr, ops
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, SETTINGS, WINDOWS, callbacks, colorstr, ops
from ultralytics.utils.checks import check_imgsz, check_imshow
from ultralytics.utils.files import increment_path
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
STREAM_WARNING = """
WARNING stream/video/webcam/dir predict source will accumulate results in RAM unless `stream=True` is passed,

View File

@ -12,9 +12,9 @@ from pathlib import Path
import numpy as np
import torch
from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.utils import LOGGER, SimpleClass, deprecation_warn, ops
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
from ultralytics.data.augment import LetterBox
from ultralytics.utils import LOGGER, SimpleClass, deprecation_warn, ops
from ultralytics.utils.plotting import Annotator, colors, save_one_box
class BaseTensor(SimpleClass):

View File

@ -21,17 +21,17 @@ from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from ultralytics.cfg import get_cfg
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks,
clean_url, colorstr, emojis, yaml_save)
from ultralytics.yolo.utils.autobatch import check_train_batch_size
from ultralytics.yolo.utils.checks import check_amp, check_file, check_imgsz, print_args
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.yolo.utils.files import get_latest_run, increment_path
from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle,
select_device, strip_optimizer)
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, clean_url,
colorstr, emojis, yaml_save)
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run, increment_path
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
strip_optimizer)
class BaseTrainer:

View File

@ -25,14 +25,14 @@ from pathlib import Path
import torch
from tqdm import tqdm
from ultralytics.cfg import get_cfg
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device, smart_inference_mode
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.files import increment_path
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
class BaseValidator:

View File

@ -2,10 +2,10 @@
import requests
from ultralytics.data.utils import HUBDatasetStats
from ultralytics.hub.auth import Auth
from ultralytics.hub.utils import PREFIX
from ultralytics.yolo.data.utils import HUBDatasetStats
from ultralytics.yolo.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
from ultralytics.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
def login(api_key=''):
@ -65,7 +65,7 @@ def reset_model(model_id=''):
def export_fmts_hub():
"""Returns a list of HUB-supported export formats."""
from ultralytics.yolo.engine.exporter import export_formats
from ultralytics.engine.exporter import export_formats
return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']

View File

@ -3,7 +3,7 @@
import requests
from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, request_with_credentials
from ultralytics.yolo.utils import LOGGER, SETTINGS, emojis, is_colab, set_settings
from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab, set_settings
API_KEY_URL = 'https://hub.ultralytics.com/settings?tab=api+keys'

View File

@ -7,8 +7,8 @@ from time import sleep
import requests
from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, smart_request
from ultralytics.yolo.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
from ultralytics.yolo.utils.errors import HUBModelError
from ultralytics.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
from ultralytics.utils.errors import HUBModelError
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'

View File

@ -11,9 +11,8 @@ from pathlib import Path
import requests
from tqdm import tqdm
from ultralytics.yolo.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM_BAR_FORMAT,
TryExcept, __version__, colorstr, get_git_origin_url, is_colab, is_git_dir,
is_pip_package)
from ultralytics.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM_BAR_FORMAT, TryExcept,
__version__, colorstr, get_git_origin_url, is_colab, is_git_dir, is_pip_package)
PREFIX = colorstr('Ultralytics HUB: ')
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'

View File

@ -1,5 +1,3 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .rtdetr import RTDETR
from .sam import SAM

View File

@ -9,13 +9,13 @@ Usage - Predict:
results = model.predict('ultralytics/assets/bus.jpg')
"""
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.engine.model import YOLO
from ultralytics.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from ...yolo.utils.torch_utils import model_info, smart_inference_mode
from .predict import FastSAMPredictor
@ -41,7 +41,7 @@ class FastSAM(YOLO):
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'

View File

@ -2,10 +2,10 @@
import torch
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.fastsam.utils import bbox_iou
from ultralytics.yolo.utils import DEFAULT_CFG, ops
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
from ultralytics.engine.results import Results
from ultralytics.models.fastsam.utils import bbox_iou
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops
class FastSAMPredictor(DetectionPredictor):

View File

@ -22,7 +22,7 @@ class FastSAMPrompt:
try:
import clip # for linear_assignment
except ImportError:
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.utils.checks import check_requirements
check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source
import clip
self.clip = clip

View File

@ -7,11 +7,11 @@ import numpy as np
import torch
import torch.nn.functional as F
from ultralytics.yolo.utils import LOGGER, NUM_THREADS, ops
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.metrics import SegmentMetrics, box_iou, mask_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
from ultralytics.yolo.v8.detect import DetectionValidator
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, NUM_THREADS, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
from ultralytics.utils.plotting import output_to_target, plot_images
class FastSAMValidator(DetectionValidator):

View File

@ -13,12 +13,12 @@ from pathlib import Path
import torch
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from ...yolo.utils.torch_utils import model_info, smart_inference_mode
from .predict import NASPredictor
from .val import NASValidator
@ -65,7 +65,7 @@ class NAS:
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'

View File

@ -2,10 +2,10 @@
import torch
from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.ops import xyxy2xywh
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops
from ultralytics.utils.ops import xyxy2xywh
class NASPredictor(BasePredictor):

View File

@ -2,9 +2,9 @@
import torch
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.ops import xyxy2xywh
from ultralytics.yolo.v8.detect import DetectionValidator
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import ops
from ultralytics.utils.ops import xyxy2xywh
__all__ = ['NASValidator']

View File

@ -7,12 +7,12 @@ from pathlib import Path
import torch.nn as nn
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir
from ultralytics.yolo.utils.checks import check_imgsz
from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from .predict import RTDETRPredictor
from .train import RTDETRTrainer
@ -72,7 +72,7 @@ class RTDETR:
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
(List[ultralytics.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'

View File

@ -2,10 +2,10 @@
import torch
from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import ops
from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops
class RTDETRPredictor(BasePredictor):

View File

@ -4,9 +4,9 @@ from copy import copy
import torch
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr
from ultralytics.yolo.v8.detect import DetectionTrainer
from ultralytics.utils import DEFAULT_CFG, RANK, colorstr
from .val import RTDETRDataset, RTDETRValidator

View File

@ -6,10 +6,10 @@ import cv2
import numpy as np
import torch
from ultralytics.yolo.data import YOLODataset
from ultralytics.yolo.data.augment import Compose, Format, v8_transforms
from ultralytics.yolo.utils import colorstr, ops
from ultralytics.yolo.v8.detect import DetectionValidator
from ultralytics.data import YOLODataset
from ultralytics.data.augment import Compose, Format, v8_transforms
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import colorstr, ops
__all__ = 'RTDETRValidator', # tuple or list

View File

@ -10,7 +10,8 @@ from functools import partial
import torch
from ...yolo.utils.downloads import attempt_download_asset
from ultralytics.utils.downloads import attempt_download_asset
from .modules.decoders import MaskDecoder
from .modules.encoders import ImageEncoderViT, PromptEncoder
from .modules.sam import Sam

View File

@ -3,8 +3,8 @@
SAM model interface
"""
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.utils.torch_utils import model_info
from ultralytics.cfg import get_cfg
from ultralytics.utils.torch_utils import model_info
from .build import build_sam
from .predict import Predictor

View File

@ -17,7 +17,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from ultralytics.yolo.utils.instance import to_2tuple
from ultralytics.utils.instance import to_2tuple
class Conv2d_BN(torch.nn.Sequential):
@ -50,7 +50,7 @@ class Conv2d_BN(torch.nn.Sequential):
# NOTE: This module and timm package is needed only for training.
# from ultralytics.yolo.utils.checks import check_requirements
# from ultralytics.utils.checks import check_requirements
# check_requirements('timm')
# from timm.models.layers import DropPath as TimmDropPath
# from timm.models.layers import trunc_normal_

View File

@ -5,11 +5,11 @@ import torch
import torch.nn.functional as F
import torchvision
from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import DEFAULT_CFG, ops
from ultralytics.yolo.utils.torch_utils import select_device
from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
from ultralytics.utils.torch_utils import select_device
from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)

View File

@ -4,9 +4,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.vit.utils.ops import HungarianMatcher
from ultralytics.yolo.utils.loss import FocalLoss, VarifocalLoss
from ultralytics.yolo.utils.metrics import bbox_iou
from ultralytics.utils.loss import FocalLoss, VarifocalLoss
from ultralytics.utils.metrics import bbox_iou
from .ops import HungarianMatcher
class DETRLoss(nn.Module):

View File

@ -5,8 +5,8 @@ import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from ultralytics.yolo.utils.metrics import bbox_iou
from ultralytics.yolo.utils.ops import xywh2xyxy, xyxy2xywh
from ultralytics.utils.metrics import bbox_iou
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
class HungarianMatcher(nn.Module):

View File

@ -0,0 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models.yolo import classify, detect, pose, segment
__all__ = 'classify', 'segment', 'detect', 'pose'

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models.yolo.classify.predict import ClassificationPredictor, predict
from ultralytics.models.yolo.classify.train import ClassificationTrainer, train
from ultralytics.models.yolo.classify.val import ClassificationValidator, val
__all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val'

View File

@ -2,9 +2,9 @@
import torch
from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ROOT
class ClassificationPredictor(BasePredictor):

View File

@ -3,13 +3,13 @@
import torch
import torchvision
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
from ultralytics.yolo import v8
from ultralytics.yolo.data import ClassificationDataset, build_dataloader
from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
from ultralytics.yolo.utils.plotting import plot_images, plot_results
from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
from ultralytics.utils.plotting import plot_images, plot_results
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
class ClassificationTrainer(BaseTrainer):
@ -98,7 +98,7 @@ class ClassificationTrainer(BaseTrainer):
def get_validator(self):
"""Returns an instance of ClassificationValidator for validation."""
self.loss_names = ['loss']
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir)
def label_loss_items(self, loss_items=None, prefix='train'):
"""

View File

@ -2,11 +2,11 @@
import torch
from ultralytics.yolo.data import ClassificationDataset, build_dataloader
from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER
from ultralytics.yolo.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.yolo.utils.plotting import plot_images
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.utils.plotting import plot_images
class ClassificationValidator(BaseValidator):

View File

@ -2,9 +2,9 @@
import torch
from ultralytics.yolo.engine.predictor import BasePredictor
from ultralytics.yolo.engine.results import Results
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ROOT, ops
class DetectionPredictor(BasePredictor):

View File

@ -3,13 +3,13 @@ from copy import copy
import numpy as np
from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel
from ultralytics.yolo import v8
from ultralytics.yolo.data import build_dataloader, build_yolo_dataset
from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.yolo.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.yolo.utils.torch_utils import de_parallel, torch_distributed_zero_first
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
# BaseTrainer python usage
@ -64,7 +64,7 @@ class DetectionTrainer(BaseTrainer):
def get_validator(self):
"""Returns a DetectionValidator for YOLO model validation."""
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
return yolo.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def label_loss_items(self, loss_items=None, prefix='train'):
"""

Some files were not shown because too many files have changed in this diff Show More