`ultralytics 8.0.21` Windows, segments, YAML fixes (#655)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>
branch: single_channel
Glenn Jocher committed 2 years ago via GitHub
parent dc9502c700
commit 6c44ce21d9

@@ -51,9 +51,9 @@ body:
       label: Environment
       description: Please specify the software and hardware you used to produce the bug.
       placeholder: |
-        - YOLO: YOLOv8 🚀 v6.0-67-g60e42e1 torch 1.9.0+cu111 CUDA:0 (A100-SXM4-40GB, 40536MiB)
+        - YOLO: Ultralytics YOLOv8.0.21 🚀 Python-3.8.10 torch-1.13.1+cu117 CUDA:0 (A100-SXM-80GB, 81251MiB)
         - OS: Ubuntu 20.04
-        - Python: 3.9.0
+        - Python: 3.8.10
     validations:
       required: false

@@ -35,28 +35,29 @@ def test_train_cls():

 # Val checks -----------------------------------------------------------------------------------------------------------
 def test_val_detect():
-    run(f'yolo val detect model={MODEL}.pt data=coco8.yaml imgsz=32 epochs=1')
+    run(f'yolo val detect model={MODEL}.pt data=coco8.yaml imgsz=32')

 def test_val_segment():
-    run(f'yolo val segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32 epochs=1')
+    run(f'yolo val segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32')

 def test_val_classify():
-    pass
+    run(f'yolo val classify model={MODEL}-cls.pt data=mnist160 imgsz=32')

 # Predict checks -------------------------------------------------------------------------------------------------------
 def test_predict_detect():
-    run(f"yolo predict detect model={MODEL}.pt source={ROOT / 'assets'} imgsz=320 conf=0.25")
+    run(f"yolo predict detect model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
+    run(f"yolo predict detect model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32")

 def test_predict_segment():
-    run(f"yolo predict segment model={MODEL}-seg.pt source={ROOT / 'assets'}")
+    run(f"yolo predict segment model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32")

 def test_predict_classify():
pass run(f"yolo predict segment model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32")

 # Export checks --------------------------------------------------------------------------------------------------------

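For context, the `run()` helper these tests call is defined elsewhere in `tests/test_cli.py` and is not part of this diff; a minimal sketch of such a helper (the `subprocess` wiring here is an assumption, not copied from the repo) would be:

```python
import subprocess


def run(cmd):
    # Split the CLI string and execute it; check=True makes any non-zero
    # exit code raise CalledProcessError, which fails the pytest test
    subprocess.run(cmd.split(), check=True)


run('yolo val detect model=yolov8n.pt data=coco8.yaml imgsz=32')  # example invocation
```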
@@ -111,7 +111,9 @@ def test_export_coreml():
     model.export(format='coreml')

-def test_export_paddle():
-    model = YOLO(MODEL)
-    model.export(format='paddle')
+def test_export_paddle(enabled=False):
+    # Paddle protobuf requirements conflicting with onnx protobuf requirements
+    if enabled:
+        model = YOLO(MODEL)
+        model.export(format='paddle')

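Gating the body behind an `enabled=False` default keeps the Paddle test collected but inert. A more conventional pytest spelling of the same intent (an alternative sketch, not what this commit does) is an explicit skip marker:

```python
import pytest

from ultralytics import YOLO


@pytest.mark.skip(reason='Paddle protobuf requirements conflict with onnx protobuf requirements')
def test_export_paddle():
    model = YOLO('yolov8n.pt')  # assumed model path, for illustration only
    model.export(format='paddle')
```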
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

-__version__ = "8.0.20"
+__version__ = "8.0.21"

 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils import ops

@@ -9,8 +9,8 @@ from types import SimpleNamespace
 from typing import Dict, List, Union

 from ultralytics import __version__
-from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, ROOT, USER_CONFIG_DIR,
-                                    IterableSimpleNamespace, colorstr, yaml_load, yaml_print)
+from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, ROOT,
+                                    USER_CONFIG_DIR, IterableSimpleNamespace, colorstr, emojis, yaml_load, yaml_print)
 from ultralytics.yolo.utils.checks import check_yolo

 CLI_HELP_MSG = \
@@ -69,7 +69,7 @@ def cfg2dict(cfg):
     return cfg

-def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace], overrides: Dict = None):
+def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG, overrides: Dict = None):
     """
     Load and merge configuration data from a file or dictionary.
@@ -214,17 +214,19 @@ def entrypoint(debug=False):
     # Mode
     mode = overrides.pop('mode', None)
     model = overrides.pop('model', None)
-    if mode == 'checks':
+    if mode is None:
+        mode = DEFAULT_CFG.mode or 'predict'
+        LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
+    elif mode not in modes:
+        if mode != 'checks':
+            raise ValueError(emojis(f"ERROR ❌ Invalid 'mode={mode}'. Valid modes are {modes}."))
         LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
         check_yolo()
         return
-    elif mode is None:
-        mode = DEFAULT_CFG_DICT['mode'] or 'predict'
-        LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")

     # Model
     if model is None:
-        model = DEFAULT_CFG_DICT['model'] or 'yolov8n.pt'
+        model = DEFAULT_CFG.model or 'yolov8n.pt'
         LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
     from ultralytics.yolo.engine.model import YOLO
     model = YOLO(model)
@@ -232,21 +234,21 @@ def entrypoint(debug=False):
     # Task
     if mode == 'predict' and 'source' not in overrides:
-        overrides['source'] = DEFAULT_CFG_DICT['source'] or ROOT / "assets" if (ROOT / "assets").exists() \
+        overrides['source'] = DEFAULT_CFG.source or ROOT / "assets" if (ROOT / "assets").exists() \
             else "https://ultralytics.com/images/bus.jpg"
         LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
     elif mode in ('train', 'val'):
         if 'data' not in overrides:
-            overrides['data'] = DEFAULT_CFG_DICT['data'] or 'mnist160' if task == 'classify' \
+            overrides['data'] = DEFAULT_CFG.data or 'mnist160' if task == 'classify' \
                 else 'coco128-seg.yaml' if task == 'segment' else 'coco128.yaml'
             LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
     elif mode == 'export':
         if 'format' not in overrides:
-            overrides['format'] = DEFAULT_CFG_DICT['format'] or 'torchscript'
+            overrides['format'] = DEFAULT_CFG.format or 'torchscript'
             LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")

     # Run command in python
-    getattr(model, mode)(verbose=True, **overrides)
+    getattr(model, mode)(**overrides)

 # Special modes --------------------------------------------------------------------------------------------------------

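The recurring `DEFAULT_CFG_DICT['key']` → `DEFAULT_CFG.key` substitutions above work because `DEFAULT_CFG` is an `IterableSimpleNamespace` built from the same dictionary, trading key lookups for attribute access. A self-contained sketch of the mechanism (toy config values; the real class lives in `ultralytics.yolo.utils`):

```python
from types import SimpleNamespace


class IterableSimpleNamespace(SimpleNamespace):
    # SimpleNamespace that can also be iterated like a dict
    def __iter__(self):
        return iter(vars(self).items())

    def get(self, key, default=None):
        return getattr(self, key, default)


DEFAULT_CFG_DICT = {'mode': None, 'model': None, 'format': 'torchscript'}  # toy values
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)

print(DEFAULT_CFG.mode or 'predict')        # None falls through to 'predict'
print(DEFAULT_CFG.format or 'torchscript')  # attribute access replaces ['format']
print(dict(DEFAULT_CFG))                    # iteration recovers the original pairs
```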
@@ -44,7 +44,8 @@ class LoadStreams:
                 assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
                 assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
             cap = cv2.VideoCapture(s)
-            assert cap.isOpened(), f'{st}Failed to open {s}'
+            if not cap.isOpened():
+                raise ConnectionError(f'{st}Failed to open {s}')
             w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
             h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
             fps = cap.get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
@@ -188,8 +189,9 @@ class LoadImages:
             self._new_video(videos[0])  # new video
         else:
             self.cap = None
-        assert self.nf > 0, f'No images or videos found in {p}. ' \
-                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
+        if self.nf == 0:
+            raise FileNotFoundError(f'No images or videos found in {p}. '
+                                    f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}')

     def __iter__(self):
         self.count = 0
@@ -223,7 +225,8 @@ class LoadImages:
             # Read image
             self.count += 1
             im0 = cv2.imread(path)  # BGR
-            assert im0 is not None, f'Image Not Found {path}'
+            if im0 is None:
+                raise FileNotFoundError(f'Image Not Found {path}')
             s = f'image {self.count}/{self.nf} {path}: '

         if self.transforms:

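The assert-to-raise conversions in these loader hunks are more than style: `assert` statements are compiled away under `python -O`, so input validation would silently vanish, and `FileNotFoundError`/`ConnectionError` tell the caller far more than a bare `AssertionError`. A quick demonstration of the difference:

```python
def load_with_assert(im0):
    assert im0 is not None, 'Image Not Found'  # stripped entirely under `python -O`


def load_with_raise(im0):
    if im0 is None:
        raise FileNotFoundError('Image Not Found')  # survives any optimization level


for fn in (load_with_assert, load_with_raise):
    try:
        fn(None)
    except Exception as e:
        print(f'{fn.__name__}: {type(e).__name__}')
# Run again with `python -O`: only load_with_raise still reports an error.
```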
@@ -23,14 +23,13 @@ import numpy as np
 import psutil
 import torch
 import torchvision
-import yaml
 from PIL import ExifTags, Image, ImageOps
 from torch.utils.data import DataLoader, Dataset, dataloader, distributed
 from tqdm import tqdm

 from ultralytics.yolo.data.utils import check_det_dataset, unzip_file
 from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
-                                    is_kaggle)
+                                    is_kaggle, yaml_load)
 from ultralytics.yolo.utils.checks import check_requirements, check_yaml
 from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
 from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
@@ -1056,8 +1055,7 @@ class HUBDatasetStats():
         # Initialize class
         zipped, data_dir, yaml_path = self._unzip(Path(path))
         try:
-            with open(check_yaml(yaml_path), errors='ignore') as f:
-                data = yaml.safe_load(f)  # data dict
+            data = yaml_load(check_yaml(yaml_path))  # data dict
             if zipped:
                 data['path'] = data_dir
         except Exception as e:

@@ -129,7 +129,7 @@ class Exporter:
             overrides (dict, optional): Configuration overrides. Defaults to None.
         """
         self.args = get_cfg(cfg, overrides)
-        self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()})  # add callbacks
+        self.callbacks = defaultdict(list, callbacks.default_callbacks)  # add callbacks
         callbacks.add_integration_callbacks(self)

     @smart_inference_mode()

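The `{k: v for k, v in ...items()}` comprehension being removed here (and in the trainer, validator, and predictor below) was just a verbose shallow copy; `defaultdict(list, mapping)` copies the mapping's items at construction, so the two spellings are equivalent. Either way the copy is shallow, so the per-event callback lists stay shared with the defaults, as this small demo shows:

```python
from collections import defaultdict

default_callbacks = {'on_predict_start': [print]}  # toy stand-in for callbacks.default_callbacks

old_style = defaultdict(list, {k: v for k, v in default_callbacks.items()})
new_style = defaultdict(list, default_callbacks)

# Both copy only the outer dict; the inner lists are the same objects:
print(old_style['on_predict_start'] is default_callbacks['on_predict_start'])  # True
print(new_style['on_predict_start'] is default_callbacks['on_predict_start'])  # True

# Unknown events still get a fresh list from the factory without touching the defaults:
new_style['on_custom_event'].append(lambda *args: None)
print('on_custom_event' in default_callbacks)  # False
```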
@@ -61,8 +61,8 @@ class YOLO:
         else:
             raise NotImplementedError(f"'{suffix}' model loading not implemented")

-    def __call__(self, source=None, stream=False, verbose=False, **kwargs):
-        return self.predict(source, stream, verbose, **kwargs)
+    def __call__(self, source=None, stream=False, **kwargs):
+        return self.predict(source, stream, **kwargs)

     def _new(self, cfg: str, verbose=True):
         """
@@ -118,7 +118,7 @@ class YOLO:
         self.model.fuse()

     @smart_inference_mode()
-    def predict(self, source=None, stream=False, verbose=False, **kwargs):
+    def predict(self, source=None, stream=False, **kwargs):
         """
         Perform prediction using the YOLO model.
@@ -126,7 +126,6 @@ class YOLO:
             source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
                 Accepts all source types accepted by the YOLO model.
             stream (bool): Whether to stream the predictions or not. Defaults to False.
-            verbose (bool): Whether to print verbose information or not. Defaults to False.
             **kwargs : Additional keyword arguments passed to the predictor.
                 Check the 'configuration' section in the documentation for all available options.
@@ -143,7 +142,7 @@ class YOLO:
             self.predictor.setup_model(model=self.model)
         else:  # only update args if predictor is already setup
             self.predictor.args = get_cfg(self.predictor.args, overrides)
-        return self.predictor(source=source, stream=stream, verbose=verbose)
+        return self.predictor(source=source, stream=stream)

     @smart_inference_mode()
     def val(self, data=None, **kwargs):
@@ -234,7 +233,8 @@ class YOLO:
         """
         return self.model.names

-    def add_callback(self, event: str, func):
+    @staticmethod
+    def add_callback(event: str, func):
         """
         Add callback
         """
@@ -242,16 +242,8 @@ class YOLO:
     @staticmethod
     def _reset_ckpt_args(args):
-        args.pop("project", None)
-        args.pop("name", None)
-        args.pop("exist_ok", None)
-        args.pop("resume", None)
-        args.pop("batch", None)
-        args.pop("epochs", None)
-        args.pop("cache", None)
-        args.pop("save_json", None)
-        args.pop("half", None)
-        args.pop("v5loader", None)
-
-        # set device to '' to prevent from auto DDP usage
-        args["device"] = ''
+        for arg in 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', 'save_json', \
+                   'half', 'v5loader':
+            args.pop(arg, None)
+
+        args["device"] = ''  # set device to '' to prevent auto-DDP usage

@@ -88,7 +88,7 @@ class BasePredictor:
         self.vid_path, self.vid_writer = None, None
         self.annotator = None
         self.data_path = None
-        self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()})  # add callbacks
+        self.callbacks = defaultdict(list, callbacks.default_callbacks)  # add callbacks
         callbacks.add_integration_callbacks(self)

     def preprocess(self, img):
@@ -151,19 +151,19 @@ class BasePredictor:
         self.bs = bs

     @smart_inference_mode()
-    def __call__(self, source=None, model=None, verbose=False, stream=False):
+    def __call__(self, source=None, model=None, stream=False):
         if stream:
-            return self.stream_inference(source, model, verbose)
+            return self.stream_inference(source, model)
         else:
-            return list(self.stream_inference(source, model, verbose))  # merge list of Result into one
+            return list(self.stream_inference(source, model))  # merge list of Result into one

     def predict_cli(self):
         # Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
-        gen = self.stream_inference(verbose=True)
+        gen = self.stream_inference()
         for _ in gen:  # running CLI inference without accumulating any outputs (do not modify)
             pass

-    def stream_inference(self, source=None, model=None, verbose=False):
+    def stream_inference(self, source=None, model=None):
         self.run_callbacks("on_predict_start")

         # setup model
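With `verbose` dropped from the call chain (it now lives in `self.args.verbose`), `__call__` reduces to choosing eager versus lazy consumption of `stream_inference`, which is a generator. The pattern in isolation (hypothetical names, not the ultralytics API):

```python
def stream_inference(frames):
    # Generator: one result at a time, so memory stays flat on long videos
    for i, frame in enumerate(frames):
        yield f'result {i} for {frame!r}'


def predict(frames, stream=False):
    if stream:
        return stream_inference(frames)    # lazy: caller pulls results as needed
    return list(stream_inference(frames))  # eager: merge all results into one list


print(predict(['a.jpg', 'b.jpg']))  # two results, fully materialized
for r in predict(['cam0', 'cam1'], stream=True):
    print(r)                        # consumed lazily, one per iteration
```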
@@ -201,7 +201,7 @@ class BasePredictor:
                 p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
                 p = Path(p)

-                if verbose or self.args.save or self.args.save_txt or self.args.show:
+                if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                     s += self.write_results(i, self.results, (p, im, im0))

                 if self.args.show:
@@ -214,11 +214,11 @@ class BasePredictor:
             yield from self.results

             # Print time (inference-only)
-            if verbose:
+            if self.args.verbose:
                 LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")

         # Print results
-        if verbose and self.seen:
+        if self.args.verbose and self.seen:
             t = tuple(x.t / self.seen * 1E3 for x in self.dt)  # speeds per image
             LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape '
                         f'{(1, 3, *self.imgsz)}' % t)
@@ -243,7 +243,7 @@ class BasePredictor:
         if isinstance(source, (str, int, Path)):  # int for local usb camera
             source = str(source)

         is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
-        is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
+        is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
         webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
         screenshot = source.lower().startswith('screen')
         if is_url and is_file:

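Reordering the URL prefixes here is cosmetic, but the surrounding classification logic is worth spelling out: numeric strings are webcam indices, `.streams` files and non-file URLs go down the streaming path, and a URL ending in a media extension is a downloadable file. The same decision table, extracted into a runnable sketch with a toy subset of the format constants:

```python
from pathlib import Path

IMG_FORMATS = ('jpg', 'jpeg', 'png')  # toy subset of the real constants
VID_FORMATS = ('mp4', 'avi')


def classify_source(source: str) -> dict:
    is_file = Path(source).suffix[1:].lower() in (IMG_FORMATS + VID_FORMATS)
    is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
    webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
    screenshot = source.lower().startswith('screen')
    return {'is_file': is_file, 'is_url': is_url, 'webcam': webcam, 'screenshot': screenshot}


print(classify_source('0'))                                       # local webcam index
print(classify_source('rtsp://example.com/feed'))                 # stream -> webcam path
print(classify_source('https://ultralytics.com/images/bus.jpg'))  # URL to a file -> download
```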
@@ -85,7 +85,6 @@ class BaseTrainer:
         self.console = LOGGER
         self.validator = None
         self.model = None
-        self.callbacks = defaultdict(list)
         init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

         # Dirs

@@ -141,7 +140,7 @@ class BaseTrainer:
             self.plot_idx = [0, 1, 2]

         # Callbacks
-        self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()})  # add callbacks
+        self.callbacks = defaultdict(list, callbacks.default_callbacks)  # add callbacks
         if RANK in {0, -1}:
             callbacks.add_integration_callbacks(self)

@@ -70,7 +70,7 @@ class BaseValidator:
         if self.args.conf is None:
             self.args.conf = 0.001  # default conf=0.001

-        self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()})  # add callbacks
+        self.callbacks = defaultdict(list, callbacks.default_callbacks)  # add callbacks

     @smart_inference_mode()
     def __call__(self, trainer=None, model=None):

@@ -5,6 +5,7 @@ import inspect
 import logging.config
 import os
 import platform
+import re
 import subprocess
 import sys
 import tempfile
@@ -113,10 +114,64 @@ class IterableSimpleNamespace(SimpleNamespace):
         return getattr(self, key, default)

+def yaml_save(file='data.yaml', data=None):
+    """
+    Save YAML data to a file.
+
+    Args:
+        file (str, optional): File name. Default is 'data.yaml'.
+        data (dict, optional): Data to save in YAML format. Default is None.
+
+    Returns:
+        None: Data is saved to the specified file.
+    """
+    file = Path(file)
+    if not file.parent.exists():
+        # Create parent directories if they don't exist
+        file.parent.mkdir(parents=True, exist_ok=True)
+
+    with open(file, 'w') as f:
+        # Dump data to file in YAML format, converting Path objects to strings
+        yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
+
+def yaml_load(file='data.yaml', append_filename=False):
+    """
+    Load YAML data from a file.
+
+    Args:
+        file (str, optional): File name. Default is 'data.yaml'.
+        append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.
+
+    Returns:
+        dict: YAML data and file name.
+    """
+    with open(file, errors='ignore', encoding='utf-8') as f:
+        # Add YAML filename to dict and return
+        s = f.read()  # string
+        if not s.isprintable():  # remove special characters
+            s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)
+        return {**yaml.safe_load(s), 'yaml_file': str(file)} if append_filename else yaml.safe_load(s)
+
+def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
+    """
+    Pretty prints a yaml file or a yaml-formatted dictionary.
+
+    Args:
+        yaml_file: The file path of the yaml file or a yaml-formatted dictionary.
+
+    Returns:
+        None
+    """
+    yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
+    dump = yaml.dump(yaml_dict, default_flow_style=False)
+    LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
+
 # Default configuration
-with open(DEFAULT_CFG_PATH, errors='ignore') as f:
-    DEFAULT_CFG_DICT = yaml.safe_load(f)
+DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
 for k, v in DEFAULT_CFG_DICT.items():
     if isinstance(v, str) and v.lower() == 'none':
         DEFAULT_CFG_DICT[k] = None
 DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
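A quick round trip through the two relocated helpers (usage sketch; the non-printable-character scrub in `yaml_load` is the Windows-encoding fix from the commit title):

```python
import tempfile
from pathlib import Path

# assumes: from ultralytics.yolo.utils import yaml_save, yaml_load

file = Path(tempfile.gettempdir()) / 'demo.yaml'
yaml_save(file, {'path': Path('/datasets/coco8'), 'imgsz': 640})  # Path saved as str
data = yaml_load(file, append_filename=True)
print(data)  # {'path': '/datasets/coco8', 'imgsz': 640, 'yaml_file': '.../demo.yaml'}
```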
@@ -393,58 +448,6 @@ def threaded(func):
     return wrapper

-def yaml_save(file='data.yaml', data=None):
-    """
-    Save YAML data to a file.
-
-    Args:
-        file (str, optional): File name. Default is 'data.yaml'.
-        data (dict, optional): Data to save in YAML format. Default is None.
-
-    Returns:
-        None: Data is saved to the specified file.
-    """
-    file = Path(file)
-    if not file.parent.exists():
-        # Create parent directories if they don't exist
-        file.parent.mkdir(parents=True, exist_ok=True)
-
-    with open(file, 'w') as f:
-        # Dump data to file in YAML format, converting Path objects to strings
-        yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
-
-def yaml_load(file='data.yaml', append_filename=False):
-    """
-    Load YAML data from a file.
-
-    Args:
-        file (str, optional): File name. Default is 'data.yaml'.
-        append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.
-
-    Returns:
-        dict: YAML data and file name.
-    """
-    with open(file, errors='ignore') as f:
-        # Add YAML filename to dict and return
-        return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
-
-def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
-    """
-    Pretty prints a yaml file or a yaml-formatted dictionary.
-
-    Args:
-        yaml_file: The file path of the yaml file or a yaml-formatted dictionary.
-
-    Returns:
-        None
-    """
-    yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
-    dump = yaml.dump(yaml_dict, default_flow_style=False)
-    LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
-
 def set_sentry():
     """
     Initialize the Sentry SDK for error tracking and reporting if pytest is not currently running.

@@ -207,9 +207,9 @@ def check_file(file, suffix=''):
     # Search/download file (if necessary) and return path
     check_suffix(file, suffix)  # optional
     file = str(file)  # convert to str()
-    if Path(file).is_file() or not file:  # exists
+    if not file or ('://' not in file and Path(file).is_file()):  # exists ('://' check required in Windows Python<3.10)
         return file
-    elif file.startswith(('http:/', 'https:/')):  # download
+    elif file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')):  # download
         url = file  # warning: Pathlib turns :// -> :/
         file = Path(urllib.parse.unquote(file).split('?')[0]).name  # '%2F' to '/', split https://url.com/file.txt?auth
         if Path(file).is_file():
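The `'://' not in file` guard is one of the Windows fixes named in the commit title: on Windows with Python < 3.10, `Path('https://...').is_file()` can raise `OSError` because a URL is not a valid NT path, so URL-shaped strings must short-circuit the filesystem check. The guard in isolation:

```python
from pathlib import Path


def exists_locally(file: str) -> bool:
    # URLs never reach Path(): on Windows Python<3.10, Path(url).is_file() may raise OSError
    if not file or '://' in file:
        return False
    return Path(file).is_file()


print(exists_locally('yolov8n.pt'))                              # ordinary filesystem check
print(exists_locally('https://ultralytics.com/images/bus.jpg'))  # False, Path never consulted
```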
@@ -276,7 +276,7 @@ def git_describe(path=ROOT):  # path must be a directory
     try:
         assert (Path(path) / '.git').is_dir()
         return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
-    except Exception:
+    except AssertionError:
         return ''

@@ -104,7 +104,7 @@ def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1
     def download_one(url, dir):
         # Download 1 file
         success = True
-        if Path(url).is_file():
+        if '://' not in str(url) and Path(url).is_file():  # exists ('://' check required in Windows Python<3.10)
             f = Path(url)  # filename
         else:  # does not exist
             f = dir / Path(url).name

@@ -17,11 +17,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel as DDP

-import ultralytics
 from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER
-from ultralytics.yolo.utils.checks import git_describe
-
-from .checks import check_version
+from ultralytics.yolo.utils.checks import check_version

 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv('RANK', -1))
@@ -60,8 +57,8 @@ def DDP_model(model):

 def select_device(device='', batch=0, newline=False):
     # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
-    ver = git_describe() or ultralytics.__version__  # git commit or pip package version
-    s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
+    from ultralytics import __version__
+    s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
     device = str(device).lower()
     for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
         device = device.replace(remove, '')  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
@@ -247,6 +244,7 @@ class ModelEMA:
     """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
     Keeps a moving average of everything in the model state_dict (parameters and buffers)
     For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    To disable EMA set the `enabled` attribute to `False`.
     """

     def __init__(self, model, decay=0.9999, tau=2000, updates=0):
def __init__(self, model, decay=0.9999, tau=2000, updates=0): def __init__(self, model, decay=0.9999, tau=2000, updates=0):
@@ -256,9 +254,11 @@ class ModelEMA:
         self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
         for p in self.ema.parameters():
             p.requires_grad_(False)
+        self.enabled = True

     def update(self, model):
         # Update EMA parameters
-        self.updates += 1
-        d = self.decay(self.updates)
+        if self.enabled:
+            self.updates += 1
+            d = self.decay(self.updates)
@@ -267,10 +267,11 @@ class ModelEMA:
                 if v.dtype.is_floating_point:  # true for FP16 and FP32
                     v *= d
                     v += (1 - d) * msd[k].detach()
-        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
+            # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'

     def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
         # Update EMA attributes
-        copy_attr(self.ema, model, include, exclude)
+        if self.enabled:
+            copy_attr(self.ema, model, include, exclude)
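For reference, the ramp `decay * (1 - exp(-x / tau))` that `ModelEMA` uses starts near zero, so early EMA updates track the raw weights closely, then approaches the nominal 0.9999 as updates accumulate; the new `enabled` flag simply turns `update` and `update_attr` into no-ops. Evaluating the ramp:

```python
import math

decay, tau = 0.9999, 2000  # ModelEMA defaults
ramp = lambda x: decay * (1 - math.exp(-x / tau))  # same formula as ModelEMA.decay

for updates in (1, 100, 2000, 10000):
    print(f'updates={updates:>5}: effective decay = {ramp(updates):.4f}')
# updates=    1: 0.0005  -> EMA essentially copies the model early on
# updates= 2000: 0.6321  -> ramping up (1 - 1/e of the way there)
# updates=10000: 0.9932  -> approaching the nominal 0.9999
```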
@@ -285,8 +286,8 @@ def strip_optimizer(f='best.pt', s=''):
         strip_optimizer(f)

     Args:
-        f (str): file path to model state to strip the optimizer from. Default is 'best.pt'.
-        s (str): file path to save the model with stripped optimizer to. Default is ''. If not provided, the original file will be overwritten.
+        f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
+        s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.

     Returns:
         None
@@ -364,12 +365,12 @@ class EarlyStopping:
     Early stopping class that stops training when a specified number of epochs have passed without improvement.
     """

-    def __init__(self, patience=30):
+    def __init__(self, patience=50):
         """
         Initialize early stopping object

         Args:
-            patience (int, optional): Number of epochs to wait after fitness stops improving before stopping. Default is 30.
+            patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
         """
         self.best_fitness = 0.0  # i.e. mAP
         self.best_epoch = 0

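The default patience doubles from 30 to 50 epochs, making the stopper noticeably more forgiving on plateauing runs. A minimal sketch of the contract, reconstructed only from the attributes visible in this diff (the real implementation lives in `ultralytics.yolo.utils.torch_utils`):

```python
class EarlyStopping:
    # Sketch consistent with the visible attributes; not copied from the source
    def __init__(self, patience=50):
        self.best_fitness = 0.0  # i.e. mAP
        self.best_epoch = 0
        self.patience = patience

    def __call__(self, epoch, fitness):
        if fitness >= self.best_fitness:  # >= tolerates the zero-fitness warm-up
            self.best_epoch, self.best_fitness = epoch, fitness
        return (epoch - self.best_epoch) >= self.patience  # True -> stop training


stopper = EarlyStopping(patience=50)
for epoch, fitness in enumerate([0.1, 0.2, 0.3] + [0.25] * 60):
    if stopper(epoch, fitness):
        print(f'stopping at epoch {epoch}')  # epoch 52: 50 epochs past best_epoch=2
        break
```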