DDP, Comet, URLError fixes, improved error handling (#658)

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tungway1990 <68179274+Tungway1990@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-28 01:31:41 +01:00
committed by GitHub
parent 6c44ce21d9
commit a5410ed79e
22 changed files with 79 additions and 81 deletions

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = "8.0.21"
__version__ = "8.0.22"
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import ops

View File

@ -4,7 +4,6 @@ Common modules
"""
import math
import warnings
import torch
import torch.nn as nn
@ -155,7 +154,7 @@ class C3(nn.Module):
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))
def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
@ -274,9 +273,7 @@ class SPP(nn.Module):
def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
class SPPF(nn.Module):
@ -290,11 +287,9 @@ class SPPF(nn.Module):
def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
class Focus(nn.Module):

View File

@ -88,6 +88,11 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG, override
check_cfg_mismatch(cfg, overrides)
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
# Type checks
for k in 'project', 'name':
if isinstance(cfg[k], (int, float)):
cfg[k] = str(cfg[k])
# Return instance
return IterableSimpleNamespace(**cfg)
@ -211,12 +216,15 @@ def entrypoint(debug=False):
else:
raise argument_error(a)
# Defaults
task2model = dict(detect='yolov8n.pt', segment='yolov8n-seg.pt', classify='yolov8n-cls.pt')
task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='mnist160')
# Mode
mode = overrides.pop('mode', None)
model = overrides.pop('model', None)
mode = overrides['mode']
if mode is None:
mode = DEFAULT_CFG.mode or 'predict'
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
LOGGER.warning(f"WARNING ⚠️ 'mode=' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
elif mode not in modes:
if mode != 'checks':
raise ValueError(emojis(f"ERROR ❌ Invalid 'mode={mode}'. Valid modes are {modes}."))
@ -225,27 +233,33 @@ def entrypoint(debug=False):
return
# Model
model = overrides.pop('model', DEFAULT_CFG.model)
task = overrides.pop('task', None)
if model is None:
model = DEFAULT_CFG.model or 'yolov8n.pt'
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
model = task2model.get(task, 'yolov8n.pt')
LOGGER.warning(f"WARNING ⚠️ 'model=' is missing. Using default 'model={model}'.")
from ultralytics.yolo.engine.model import YOLO
overrides['model'] = model
model = YOLO(model)
task = model.task
# Task
if task and task != model.task:
LOGGER.warning(f"WARNING ⚠️ 'task={task}' conflicts with {model.task} model {overrides['model']}. "
f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.")
task = model.task
overrides['task'] = task
if mode == 'predict' and 'source' not in overrides:
overrides['source'] = DEFAULT_CFG.source or ROOT / "assets" if (ROOT / "assets").exists() \
else "https://ultralytics.com/images/bus.jpg"
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
LOGGER.warning(f"WARNING ⚠️ 'source=' is missing. Using default 'source={overrides['source']}'.")
elif mode in ('train', 'val'):
if 'data' not in overrides:
overrides['data'] = DEFAULT_CFG.data or 'mnist160' if task == 'classify' \
else 'coco128-seg.yaml' if task == 'segment' else 'coco128.yaml'
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
overrides['data'] = task2data.get(task, DEFAULT_CFG.data)
LOGGER.warning(f"WARNING ⚠️ 'data=' is missing. Using {model.task} default 'data={overrides['data']}'.")
elif mode == 'export':
if 'format' not in overrides:
overrides['format'] = DEFAULT_CFG.format or 'torchscript'
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
LOGGER.warning(f"WARNING ⚠️ 'format=' is missing. Using default 'format={overrides['format']}'.")
# Run command in python
getattr(model, mode)(**overrides)

View File

@ -186,6 +186,7 @@ class LoadImages:
self.transforms = transforms # optional
self.vid_stride = vid_stride # video frame-rate stride
if any(videos):
self.orientation = None # rotation degrees
self._new_video(videos[0]) # new video
else:
self.cap = None
@ -243,8 +244,10 @@ class LoadImages:
self.frame = 0
self.cap = cv2.VideoCapture(path)
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
# self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # disable https://github.com/ultralytics/yolov5/issues/8493
if hasattr(cv2, 'CAP_PROP_ORIENTATION_META'): # cv2<4.6.0 compatibility
self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
# Disable auto-orientation due to known issues in https://github.com/ultralytics/yolov5/issues/8493
# self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)
def _cv2_rotate(self, im):
# Rotate a cv2 video manually

View File

@ -16,7 +16,7 @@ from PIL import ExifTags, Image, ImageOps
from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, emojis, yaml_load
from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
from ultralytics.yolo.utils.downloads import download
from ultralytics.yolo.utils.downloads import download, safe_download
from ultralytics.yolo.utils.files import unzip_file
from ultralytics.yolo.utils.ops import segments2boxes
@ -238,8 +238,7 @@ def check_det_dataset(dataset, autodownload=True):
t = time.time()
if s.startswith('http') and s.endswith('.zip'): # URL
f = Path(s).name # filename
LOGGER.info(f'Downloading {s} to {f}...')
torch.hub.download_url_to_file(s, f)
safe_download(file=f, url=s)
Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
unzip_file(f, path=DATASETS_DIR) # unzip
Path(f).unlink() # remove zip

View File

@ -233,7 +233,7 @@ class BasePredictor:
device = select_device(self.args.device)
model = model or self.args.model
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
self.model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)
self.model = AutoBackend(model, device=device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
self.device = device
self.model.eval()

View File

@ -6,7 +6,7 @@ try:
import clearml
from clearml import Task
assert hasattr(clearml, '__version__')
assert clearml.__version__ # verify package is not directory
except (ImportError, AssertionError):
clearml = None

View File

@ -5,7 +5,7 @@ from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
try:
import comet_ml
except (ModuleNotFoundError, ImportError):
except ImportError:
comet_ml = None
@ -35,7 +35,7 @@ def on_fit_epoch_end(trainer):
def on_train_end(trainer):
experiment = comet_ml.get_global_experiment()
experiment.log_model("YOLOv8", file_or_folder=trainer.best, file_name="best.pt", overwrite=True)
experiment.log_model("YOLOv8", file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True)
callbacks = {

View File

@ -18,8 +18,8 @@ import psutil
import torch
from IPython import display
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis,
is_colab, is_docker, is_jupyter)
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads,
emojis, is_colab, is_docker, is_jupyter)
def is_ascii(s) -> bool:
@ -123,9 +123,7 @@ def check_font(font: str = FONT, progress: bool = False) -> None:
# Check if font file exists at the source or destination path
if not font.exists() and not file.exists():
# Download font file
url = f'https://ultralytics.com/assets/{font.name}'
LOGGER.info(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, str(file), progress=progress)
downloads.safe_download(file=file, url=f'https://ultralytics.com/assets/{font.name}', progress=progress)
def check_online() -> bool:
@ -215,9 +213,7 @@ def check_file(file, suffix=''):
if Path(file).is_file():
LOGGER.info(f'Found {url} locally at {file}') # file already exists
else:
LOGGER.info(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, file)
assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
downloads.safe_download(file=file, url=url)
return file
else: # search
files = []

View File

@ -12,16 +12,16 @@ from zipfile import ZipFile
import requests
import torch
from ultralytics.yolo.utils import LOGGER, SETTINGS
from ultralytics.yolo.utils import LOGGER
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg='', progress=True):
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
file = Path(file)
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
try: # url1
LOGGER.info(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
torch.hub.download_url_to_file(url, str(file), progress=progress and LOGGER.level <= logging.INFO)
assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
except Exception as e: # url2
if file.exists():
@ -32,7 +32,7 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
if not file.exists() or file.stat().st_size < min_bytes: # check
if file.exists():
file.unlink() # remove partial downloads
LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}")
LOGGER.warning(f"ERROR: {assert_msg}\n{error_msg}")
LOGGER.info('')
@ -49,6 +49,7 @@ def is_url(url, check=True):
def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
from ultralytics.yolo.utils import SETTINGS
def github_assets(repository, version='latest'):
# Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov5m.pt', ...])
@ -76,7 +77,6 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
return file
# GitHub assets
assets = [f'yolov5{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] # default
assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] # default
try:
tag, assets = github_assets(repo, release)
@ -110,13 +110,12 @@ def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1
f = dir / Path(url).name
LOGGER.info(f'Downloading {url} to {f}...')
for i in range(retry + 1):
if curl:
s = 'sS' if threads > 1 else '' # silent
r = os.system(
f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
if curl: # curl download with retry, continue
s = 'sS' * (threads > 1) # silent
r = os.system(f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')
success = r == 0
else:
torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
else: # torch download
torch.hub.download_url_to_file(url, f, progress=threads == 1)
success = f.is_file()
if success:
break

View File

@ -67,9 +67,19 @@ def select_device(device='', batch=0, newline=False):
if cpu or mps:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
elif device: # non-cpu device requested
visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
raise ValueError(f"Invalid CUDA 'device={device}' requested, use 'device=cpu' or pass valid CUDA device(s)")
LOGGER.info(s)
install = "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no " \
"CUDA devices are seen by torch.\n" if torch.cuda.device_count() == 0 else ""
raise ValueError(f"Invalid CUDA 'device={device}' requested."
f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
f"{install}")
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7

View File

@ -69,7 +69,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
else "https://ultralytics.com/images/bus.jpg"
args = dict(model=model, source=source, verbose=True)
args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)

View File

@ -141,7 +141,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
data = cfg.data or "mnist160" # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device, verbose=True)
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)

View File

@ -50,7 +50,7 @@ def val(cfg=DEFAULT_CFG, use_python=False):
model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
data = cfg.data or "mnist160"
args = dict(model=model, data=data, verbose=True)
args = dict(model=model, data=data)
if use_python:
from ultralytics import YOLO
YOLO(model).val(**args)

View File

@ -87,7 +87,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
else "https://ultralytics.com/images/bus.jpg"
args = dict(model=model, source=source, verbose=True)
args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)

View File

@ -199,7 +199,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device, verbose=True)
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)

View File

@ -129,7 +129,7 @@ class DetectionValidator(BaseValidator):
f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
# Print results per class
if (self.args.verbose or not self.training) and self.nc > 1 and len(self.stats):
if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
for i, c in enumerate(self.metrics.ap_class_index):
self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
@ -237,7 +237,7 @@ def val(cfg=DEFAULT_CFG, use_python=False):
model = cfg.model or "yolov8n.pt"
data = cfg.data or "coco128.yaml"
args = dict(model=model, data=data, verbose=True)
args = dict(model=model, data=data)
if use_python:
from ultralytics import YOLO
YOLO(model).val(**args)

View File

@ -105,7 +105,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
source = cfg.source if cfg.source is not None else ROOT / "assets" if (ROOT / "assets").exists() \
else "https://ultralytics.com/images/bus.jpg"
args = dict(model=model, source=source, verbose=True)
args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)

View File

@ -145,7 +145,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device, verbose=True)
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)

View File

@ -247,7 +247,7 @@ def val(cfg=DEFAULT_CFG, use_python=False):
model = cfg.model or "yolov8n-seg.pt"
data = cfg.data or "coco128-seg.yaml"
args = dict(model=model, data=data, verbose=True)
args = dict(model=model, data=data)
if use_python:
from ultralytics import YOLO
YOLO(model).val(**args)