`imgsz` warning fix, download function consolidation (#681)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: HaeJin Lee <seareale@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 0609561549
commit 899abe9f82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = "8.0.22" __version__ = "8.0.23"
from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import ops from ultralytics.yolo.utils import ops

@ -14,7 +14,7 @@ from PIL import Image
from ultralytics.yolo.utils import LOGGER, ROOT, yaml_load from ultralytics.yolo.utils import LOGGER, ROOT, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version
from ultralytics.yolo.utils.downloads import attempt_download, is_url from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
from ultralytics.yolo.utils.ops import xywh2xyxy from ultralytics.yolo.utils.ops import xywh2xyxy
@ -58,7 +58,7 @@ class AutoBackend(nn.Module):
model = None # TODO: resolves ONNX inference, verify effect on other backends model = None # TODO: resolves ONNX inference, verify effect on other backends
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
if not (pt or triton or nn_module): if not (pt or triton or nn_module):
w = attempt_download(w) # download if not local w = attempt_download_asset(w) # download if not local
# NOTE: special case: in-memory pytorch model # NOTE: special case: in-memory pytorch model
if nn_module: if nn_module:

@ -325,9 +325,9 @@ def torch_safe_load(weight):
Returns: Returns:
The loaded PyTorch model. The loaded PyTorch model.
""" """
from ultralytics.yolo.utils.downloads import attempt_download from ultralytics.yolo.utils.downloads import attempt_download_asset
file = attempt_download(weight) # search online if missing locally file = attempt_download_asset(weight) # search online if missing locally
try: try:
return torch.load(file, map_location='cpu') # load return torch.load(file, map_location='cpu') # load
except ModuleNotFoundError as e: except ModuleNotFoundError as e:

@ -90,7 +90,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG, override
# Type checks # Type checks
for k in 'project', 'name': for k in 'project', 'name':
if isinstance(cfg[k], (int, float)): if k in cfg and isinstance(cfg[k], (int, float)):
cfg[k] = str(cfg[k]) cfg[k] = str(cfg[k])
# Return instance # Return instance
@ -176,7 +176,7 @@ def entrypoint(debug=False):
'version': lambda: LOGGER.info(__version__), 'version': lambda: LOGGER.info(__version__),
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'), 'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH), 'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
'copy-cfg': copy_default_config} 'copy-cfg': copy_default_cfg}
overrides = {} # basic overrides, i.e. imgsz=320 overrides = {} # basic overrides, i.e. imgsz=320
for a in merge_equals_args(args): # merge spaces around '=' sign for a in merge_equals_args(args): # merge spaces around '=' sign
@ -221,7 +221,7 @@ def entrypoint(debug=False):
task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='mnist160') task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='mnist160')
# Mode # Mode
mode = overrides['mode'] mode = overrides.get('mode', None)
if mode is None: if mode is None:
mode = DEFAULT_CFG.mode or 'predict' mode = DEFAULT_CFG.mode or 'predict'
LOGGER.warning(f"WARNING ⚠️ 'mode=' is missing. Valid modes are {modes}. Using default 'mode={mode}'.") LOGGER.warning(f"WARNING ⚠️ 'mode=' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
@ -266,7 +266,7 @@ def entrypoint(debug=False):
# Special modes -------------------------------------------------------------------------------------------------------- # Special modes --------------------------------------------------------------------------------------------------------
def copy_default_config(): def copy_default_cfg():
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml') new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
shutil.copy2(DEFAULT_CFG_PATH, new_file) shutil.copy2(DEFAULT_CFG_PATH, new_file)
LOGGER.info(f"{PREFIX}{DEFAULT_CFG_PATH} copied to {new_file}\n" LOGGER.info(f"{PREFIX}{DEFAULT_CFG_PATH} copied to {new_file}\n"

@ -44,8 +44,20 @@ class Compose:
self.transforms = transforms self.transforms = transforms
def __call__(self, data): def __call__(self, data):
mosaic_p = None
mosaic_imgsz = None
for t in self.transforms: for t in self.transforms:
data = t(data) if isinstance(t, Mosaic):
temp = t(data)
mosaic_p = False if temp == data else True
mosaic_imgsz = t.imgsz
data = temp
else:
if isinstance(t, RandomPerspective):
t.border = [-mosaic_imgsz // 2, -mosaic_imgsz // 2] if mosaic_p else [0, 0]
data = t(data)
return data return data
def append(self, transform): def append(self, transform):

@ -120,7 +120,8 @@ class BaseDataset(Dataset):
im = np.load(fn) im = np.load(fn)
else: # read image else: # read image
im = cv2.imread(f) # BGR im = cv2.imread(f) # BGR
assert im is not None, f"Image Not Found {f}" if im is None:
raise FileNotFoundError(f"Image Not Found {f}")
h0, w0 = im.shape[:2] # orig hw h0, w0 = im.shape[:2] # orig hw
r = self.imgsz / max(h0, w0) # ratio r = self.imgsz / max(h0, w0) # ratio
if r != 1: # if sizes are not equal if r != 1: # if sizes are not equal

@ -65,7 +65,7 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
assert mode in ["train", "val"] assert mode in ["train", "val"]
shuffle = mode == "train" shuffle = mode == "train"
if cfg.rect and shuffle: if cfg.rect and shuffle:
LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False") LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False shuffle = False
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = YOLODataset( dataset = YOLODataset(

@ -64,7 +64,7 @@ download: |
# Download # Download
dir = Path(yaml['path']) # dataset root dir dir = Path(yaml['path']) # dataset root dir
urls = ['https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip'] urls = ['https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip']
download(urls, dir=dir, delete=False) download(urls, dir=dir)
# Convert # Convert
annotations_dir = 'Argoverse-HD/annotations/' annotations_dir = 'Argoverse-HD/annotations/'

@ -411,12 +411,12 @@ download: |
# Download # Download
url = f"https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/{split}/" url = f"https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/{split}/"
if split == 'train': if split == 'train':
download([f'{url}zhiyuan_objv2_{split}.tar.gz'], dir=dir, delete=False) # annotations json download([f'{url}zhiyuan_objv2_{split}.tar.gz'], dir=dir) # annotations json
download([f'{url}patch{i}.tar.gz' for i in range(patches)], dir=images, curl=True, delete=False, threads=8) download([f'{url}patch{i}.tar.gz' for i in range(patches)], dir=images, curl=True, threads=8)
elif split == 'val': elif split == 'val':
download([f'{url}zhiyuan_objv2_{split}.json'], dir=dir, delete=False) # annotations json download([f'{url}zhiyuan_objv2_{split}.json'], dir=dir) # annotations json
download([f'{url}images/v1/patch{i}.tar.gz' for i in range(15 + 1)], dir=images, curl=True, delete=False, threads=8) download([f'{url}images/v1/patch{i}.tar.gz' for i in range(15 + 1)], dir=images, curl=True, threads=8)
download([f'{url}images/v2/patch{i}.tar.gz' for i in range(16, patches)], dir=images, curl=True, delete=False, threads=8) download([f'{url}images/v2/patch{i}.tar.gz' for i in range(16, patches)], dir=images, curl=True, threads=8)
# Move # Move
for f in tqdm(images.rglob('*.jpg'), desc=f'Moving {split} images'): for f in tqdm(images.rglob('*.jpg'), desc=f'Moving {split} images'):

@ -34,7 +34,7 @@ download: |
dir = Path(yaml['path']) # dataset root dir dir = Path(yaml['path']) # dataset root dir
parent = Path(dir.parent) # download dir parent = Path(dir.parent) # download dir
urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz'] urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz']
download(urls, dir=parent, delete=False) download(urls, dir=parent)
# Rename directories # Rename directories
if dir.exists(): if dir.exists():

@ -81,7 +81,7 @@ download: |
urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images
f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images
f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images
download(urls, dir=dir / 'images', delete=False, curl=True, threads=3) download(urls, dir=dir / 'images', curl=True, threads=3)
# Convert # Convert
path = dir / 'images/VOCdevkit' path = dir / 'images/VOCdevkit'

@ -138,7 +138,7 @@ download: |
# urls = ['https://d307kc0mrhucc3.cloudfront.net/train_labels.zip', # train labels # urls = ['https://d307kc0mrhucc3.cloudfront.net/train_labels.zip', # train labels
# 'https://d307kc0mrhucc3.cloudfront.net/train_images.zip', # 15G, 847 train images # 'https://d307kc0mrhucc3.cloudfront.net/train_images.zip', # 15G, 847 train images
# 'https://d307kc0mrhucc3.cloudfront.net/val_images.zip'] # 5G, 282 val images (no labels) # 'https://d307kc0mrhucc3.cloudfront.net/val_images.zip'] # 5G, 282 val images (no labels)
# download(urls, dir=dir, delete=False) # download(urls, dir=dir)
# Convert labels # Convert labels
convert_labels(dir / 'xView_train.geojson') convert_labels(dir / 'xView_train.geojson')

@ -237,11 +237,7 @@ def check_det_dataset(dataset, autodownload=True):
raise FileNotFoundError(msg) raise FileNotFoundError(msg)
t = time.time() t = time.time()
if s.startswith('http') and s.endswith('.zip'): # URL if s.startswith('http') and s.endswith('.zip'): # URL
f = Path(s).name # filename safe_download(url=s, dir=DATASETS_DIR, delete=True)
safe_download(file=f, url=s)
Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
unzip_file(f, path=DATASETS_DIR) # unzip
Path(f).unlink() # remove zip
r = None # success r = None # success
elif s.startswith('bash '): # bash script elif s.startswith('bash '): # bash script
LOGGER.info(f'Running {s} ...') LOGGER.info(f'Running {s} ...')
@ -251,7 +247,7 @@ def check_det_dataset(dataset, autodownload=True):
dt = f'({round(time.time() - t, 1)}s)' dt = f'({round(time.time() - t, 1)}s)'
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt}" s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt}"
LOGGER.info(f"Dataset download {s}") LOGGER.info(f"Dataset download {s}")
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
return data # dictionary return data # dictionary

@ -125,7 +125,7 @@ class Exporter:
Initializes the Exporter class. Initializes the Exporter class.
Args: Args:
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
overrides (dict, optional): Configuration overrides. Defaults to None. overrides (dict, optional): Configuration overrides. Defaults to None.
""" """
self.args = get_cfg(cfg, overrides) self.args = get_cfg(cfg, overrides)

@ -28,7 +28,6 @@ DEFAULT_CFG_PATH = ROOT / "yolo/cfg/default.yaml"
RANK = int(os.getenv('RANK', -1)) RANK = int(os.getenv('RANK', -1))
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format
LOGGING_NAME = 'ultralytics' LOGGING_NAME = 'ultralytics'
@ -328,6 +327,20 @@ def get_git_origin_url():
return None # if not git dir or on error return None # if not git dir or on error
def get_git_branch():
"""
Returns the current git branch name. If not in a git repository, returns None.
Returns:
(str) or (None): The current git branch name.
"""
if is_git_dir():
with contextlib.suppress(subprocess.CalledProcessError):
origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
return origin.decode().strip()
return None # if not git dir or on error
def get_default_args(func): def get_default_args(func):
# Get func() default arguments # Get func() default arguments
signature = inspect.signature(func) signature = inspect.signature(func)
@ -466,7 +479,8 @@ def set_sentry():
if SETTINGS['sync'] and \ if SETTINGS['sync'] and \
not is_pytest_running() and \ not is_pytest_running() and \
not is_github_actions_ci() and \ not is_github_actions_ci() and \
(is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git"): (is_pip_package() or
(get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")):
import sentry_sdk # noqa import sentry_sdk # noqa
import ultralytics import ultralytics

@ -28,7 +28,7 @@ def autobatch(model, imgsz=640, fraction=0.7, batch_size=16):
# Check device # Check device
prefix = colorstr('AutoBatch: ') prefix = colorstr('AutoBatch: ')
LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}') LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
device = next(model.parameters()).device # get model device device = next(model.parameters()).device # get model device
if device.type == 'cpu': if device.type == 'cpu':
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}') LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')

@ -17,9 +17,10 @@ import pkg_resources as pkg
import psutil import psutil
import torch import torch
from IPython import display from IPython import display
from matplotlib import font_manager
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads, from ultralytics.yolo.utils import (AUTOINSTALL, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads, emojis,
emojis, is_colab, is_docker, is_jupyter) is_colab, is_docker, is_jupyter)
def is_ascii(s) -> bool: def is_ascii(s) -> bool:
@ -57,15 +58,14 @@ def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride) stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
# Convert image size to list if it is an integer # Convert image size to list if it is an integer
if isinstance(imgsz, int): imgsz = [imgsz] if isinstance(imgsz, int) else list(imgsz)
imgsz = [imgsz]
# Make image size a multiple of the stride # Make image size a multiple of the stride
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz] sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
# Print warning message if image size was updated # Print warning message if image size was updated
if sz != imgsz: if sz != imgsz:
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {stride}, updating to {sz}') LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}')
# Add missing dimensions if necessary # Add missing dimensions if necessary
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
@ -104,26 +104,33 @@ def check_version(current: str = "0.0.0",
return result return result
def check_font(font: str = FONT, progress: bool = False) -> None: def check_font(font='Arial.ttf'):
""" """
Download font file to the user's configuration directory if it does not already exist. Find font locally or download to user's configuration directory if it does not already exist.
Args: Args:
font (str): Path to font file. font (str): Path or name of font.
progress (bool): If True, display a progress bar during the download.
Returns: Returns:
None file (Path): Resolved font file path.
""" """
font = Path(font) name = Path(font).name
# Check USER_CONFIG_DIR
file = USER_CONFIG_DIR / name
if file.exists():
return file
# Destination path for the font file # Check system fonts
file = USER_CONFIG_DIR / font.name matches = [s for s in font_manager.findSystemFonts() if font in s]
if any(matches):
return matches[0]
# Check if font file exists at the source or destination path # Download to USER_CONFIG_DIR if missing
if not font.exists() and not file.exists(): url = f'https://ultralytics.com/assets/{name}'
# Download font file if downloads.is_url(url):
downloads.safe_download(file=file, url=f'https://ultralytics.com/assets/{font.name}', progress=progress) downloads.safe_download(url=url, file=file)
return file
def check_online() -> bool: def check_online() -> bool:
@ -213,7 +220,7 @@ def check_file(file, suffix=''):
if Path(file).is_file(): if Path(file).is_file():
LOGGER.info(f'Found {url} locally at {file}') # file already exists LOGGER.info(f'Found {url} locally at {file}') # file already exists
else: else:
downloads.safe_download(file=file, url=url) downloads.safe_download(url=url, file=file)
return file return file
else: # search else: # search
files = [] files = []

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
import logging import contextlib
import os import os
import subprocess import subprocess
import urllib import urllib
@ -15,27 +15,6 @@ import torch
from ultralytics.yolo.utils import LOGGER from ultralytics.yolo.utils import LOGGER
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg='', progress=True):
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
file = Path(file)
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
try: # url1
LOGGER.info(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, str(file), progress=progress and LOGGER.level <= logging.INFO)
assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
except Exception as e: # url2
if file.exists():
file.unlink() # remove partial downloads
LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
finally:
if not file.exists() or file.stat().st_size < min_bytes: # check
if file.exists():
file.unlink() # remove partial downloads
LOGGER.warning(f"ERROR: {assert_msg}\n{error_msg}")
LOGGER.info('')
def is_url(url, check=True): def is_url(url, check=True):
# Check if string is URL and check if URL exists # Check if string is URL and check if URL exists
try: try:
@ -47,7 +26,71 @@ def is_url(url, check=True):
return False return False
def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'): def safe_download(url,
file=None,
dir=None,
unzip=True,
delete=False,
curl=False,
retry=3,
min_bytes=1E0,
progress=True):
"""
Function for downloading files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
Args:
url: str: The URL of the file to be downloaded.
file: str, optional: The filename of the downloaded file.
If not provided, the file will be saved with the same name as the URL.
dir: str, optional: The directory to save the downloaded file.
If not provided, the file will be saved in the current working directory.
unzip: bool, optional: Whether to unzip the downloaded file. Default: True.
delete: bool, optional: Whether to delete the downloaded file after unzipping. Default: False.
curl: bool, optional: Whether to use curl command line tool for downloading. Default: False.
retry: int, optional: The number of times to retry the download in case of failure. Default: 3.
min_bytes: float, optional: The minimum number of bytes that the downloaded file should have, to be considered
a successful download. Default: 1E0.
progress: bool, optional: Whether to display a progress bar during the download. Default: True.
"""
if '://' not in str(url) and Path(url).is_file(): # exists ('://' check required in Windows Python<3.10)
f = Path(url) # filename
else: # does not exist
assert dir or file, 'dir or file required for download'
f = dir / Path(url).name if dir else Path(file)
LOGGER.info(f'Downloading {url} to {f}...')
f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
for i in range(retry + 1):
try:
if curl or i > 0: # curl download with retry, continue
s = 'sS' * (not progress) # silent
r = os.system(f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')
else: # torch download
r = torch.hub.download_url_to_file(url, f, progress=progress)
assert r in {0, None}
except Exception as e:
if i >= retry:
raise ConnectionError(f'❌ Download failure for {url}') from e
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
continue
if f.exists():
if f.stat().st_size > min_bytes:
break # success
f.unlink() # remove partial downloads
if unzip and f.exists() and f.suffix in {'.zip', '.tar', '.gz'}:
LOGGER.info(f'Unzipping {f}...')
if f.suffix == '.zip':
ZipFile(f).extractall(path=f.parent) # unzip
elif f.suffix == '.tar':
os.system(f'tar xf {f} --directory {f.parent}') # unzip
elif f.suffix == '.gz':
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
if delete:
f.unlink() # remove zip
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc. # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
from ultralytics.yolo.utils import SETTINGS from ultralytics.yolo.utils import SETTINGS
@ -73,7 +116,7 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
if Path(file).is_file(): if Path(file).is_file():
LOGGER.info(f'Found {url} locally at {file}') # file already exists LOGGER.info(f'Found {url} locally at {file}') # file already exists
else: else:
safe_download(file=file, url=url, min_bytes=1E5) safe_download(url=url, file=file, min_bytes=1E5)
return file return file
# GitHub assets # GitHub assets
@ -91,61 +134,23 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required) file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
if name in assets: if name in assets:
safe_download(file, safe_download(url=f'https://github.com/{repo}/releases/download/{tag}/{name}', file=file, min_bytes=1E5)
url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
min_bytes=1E5,
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag}')
return str(file) return str(file)
def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1, retry=3): def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3):
# Multithreaded file download and unzip function, used in data.yaml for autodownload # Multithreaded file download and unzip function, used in data.yaml for autodownload
def download_one(url, dir):
# Download 1 file
success = True
if '://' not in str(url) and Path(url).is_file(): # exists ('://' check required in Windows Python<3.10)
f = Path(url) # filename
else: # does not exist
f = dir / Path(url).name
LOGGER.info(f'Downloading {url} to {f}...')
for i in range(retry + 1):
if curl: # curl download with retry, continue
s = 'sS' * (threads > 1) # silent
r = os.system(f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')
success = r == 0
else: # torch download
torch.hub.download_url_to_file(url, f, progress=threads == 1)
success = f.is_file()
if success:
break
elif i < retry:
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
else:
LOGGER.warning(f'❌ Failed to download {url}...')
if unzip and success and f.suffix in ('.zip', '.tar', '.gz'):
LOGGER.info(f'Unzipping {f}...')
if f.suffix == '.zip':
ZipFile(f).extractall(path=dir) # unzip
elif f.suffix == '.tar':
os.system(f'tar xf {f} --directory {f.parent}') # unzip
elif f.suffix == '.gz':
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
if delete:
f.unlink() # remove zip
dir = Path(dir) dir = Path(dir)
dir.mkdir(parents=True, exist_ok=True) # make directory dir.mkdir(parents=True, exist_ok=True) # make directory
if threads > 1: if threads > 1:
# pool = ThreadPool(threads)
# pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
# pool.close()
# pool.join()
with ThreadPool(threads) as pool: with ThreadPool(threads) as pool:
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded pool.map(
lambda x: safe_download(
url=x[0], dir=x[1], unzip=unzip, delete=delete, curl=curl, retry=retry, progress=threads <= 1),
zip(url, repeat(dir)))
pool.close() pool.close()
pool.join() pool.join()
else: else:
for u in [url] if isinstance(url, (str, Path)) else url: for u in [url] if isinstance(url, (str, Path)) else url:
download_one(u, dir) safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry)

@ -3,7 +3,6 @@
import contextlib import contextlib
import math import math
from pathlib import Path from pathlib import Path
from urllib.error import URLError
import cv2 import cv2
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -12,9 +11,9 @@ import pandas as pd
import torch import torch
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR, threaded from ultralytics.yolo.utils import threaded
from .checks import check_font, check_requirements, is_ascii from .checks import check_font, is_ascii
from .files import increment_path from .files import increment_path
from .ops import clip_coords, scale_image, xywh2xyxy, xyxy2xywh from .ops import clip_coords, scale_image, xywh2xyxy, xyxy2xywh
@ -49,14 +48,20 @@ class Annotator:
if self.pil: # use PIL if self.pil: # use PIL
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im) self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
self.draw = ImageDraw.Draw(self.im) self.draw = ImageDraw.Draw(self.im)
self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font, try:
size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)) font = check_font('Arial.Unicode.ttf' if non_ascii else font)
size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
self.font = ImageFont.truetype(str(font), size)
except Exception:
self.font = ImageFont.load_default()
else: # use cv2 else: # use cv2
self.im = im self.im = im
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)): def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
# Add one xyxy box to image with label # Add one xyxy box to image with label
if isinstance(box, torch.Tensor):
box = box.tolist()
if self.pil or not is_ascii(label): if self.pil or not is_ascii(label):
self.draw.rectangle(box, width=self.lw, outline=color) # box self.draw.rectangle(box, width=self.lw, outline=color) # box
if label: if label:
@ -139,22 +144,6 @@ class Annotator:
return np.asarray(self.im) return np.asarray(self.im)
def check_pil_font(font=FONT, size=10):
# Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
font = Path(font)
font = font if font.exists() else (USER_CONFIG_DIR / font.name)
try:
return ImageFont.truetype(str(font) if font.exists() else font.name, size)
except Exception: # download if missing
try:
check_font(font)
return ImageFont.truetype(str(font), size)
except TypeError:
check_requirements('Pillow>=8.4.0') # known issue https://github.com/ultralytics/yolov5/issues/5374
except URLError: # not online
return ImageFont.load_default()
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True): def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
xyxy = torch.tensor(xyxy).view(-1, 4) xyxy = torch.tensor(xyxy).view(-1, 4)

@ -85,8 +85,8 @@ def select_device(device='', batch=0, newline=False):
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7 devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
n = len(devices) # device count n = len(devices) # device count
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
raise ValueError(f'batch={batch} is not multiple of GPU count {n}.\n' raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
f'Try batch={batch // n} or batch={batch // n + 1}') f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.")
space = ' ' * (len(s) + 1) space = ' ' * (len(s) + 1)
for i, d in enumerate(devices): for i, d in enumerate(devices):
p = torch.cuda.get_device_properties(i) p = torch.cuda.get_device_properties(i)

@ -74,7 +74,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
from ultralytics import YOLO from ultralytics import YOLO
YOLO(model)(**args) YOLO(model)(**args)
else: else:
predictor = ClassificationPredictor(args) predictor = ClassificationPredictor(overrides=args)
predictor.predict_cli() predictor.predict_cli()

@ -146,7 +146,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
from ultralytics import YOLO from ultralytics import YOLO
YOLO(model).train(**args) YOLO(model).train(**args)
else: else:
trainer = ClassificationTrainer(args) trainer = ClassificationTrainer(overrides=args)
trainer.train() trainer.train()

@ -92,7 +92,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
from ultralytics import YOLO from ultralytics import YOLO
YOLO(model)(**args) YOLO(model)(**args)
else: else:
predictor = DetectionPredictor(args) predictor = DetectionPredictor(overrides=args)
predictor.predict_cli() predictor.predict_cli()

@ -204,7 +204,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
from ultralytics import YOLO from ultralytics import YOLO
YOLO(model).train(**args) YOLO(model).train(**args)
else: else:
trainer = DetectionTrainer(args) trainer = DetectionTrainer(overrides=args)
trainer.train() trainer.train()

@ -110,7 +110,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
from ultralytics import YOLO from ultralytics import YOLO
YOLO(model)(**args) YOLO(model)(**args)
else: else:
predictor = SegmentationPredictor(args) predictor = SegmentationPredictor(overrides=args)
predictor.predict_cli() predictor.predict_cli()

@ -150,7 +150,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
from ultralytics import YOLO from ultralytics import YOLO
YOLO(model).train(**args) YOLO(model).train(**args)
else: else:
trainer = SegmentationTrainer(args) trainer = SegmentationTrainer(overrides=args)
trainer.train() trainer.train()

Loading…
Cancel
Save