`imgsz` warning fix, download function consolidation (#681)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: HaeJin Lee <seareale@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 0609561549
commit 899abe9f82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = "8.0.22"
__version__ = "8.0.23"
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import ops

@ -14,7 +14,7 @@ from PIL import Image
from ultralytics.yolo.utils import LOGGER, ROOT, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version
from ultralytics.yolo.utils.downloads import attempt_download, is_url
from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
from ultralytics.yolo.utils.ops import xywh2xyxy
@ -58,7 +58,7 @@ class AutoBackend(nn.Module):
model = None # TODO: resolves ONNX inference, verify effect on other backends
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
if not (pt or triton or nn_module):
w = attempt_download(w) # download if not local
w = attempt_download_asset(w) # download if not local
# NOTE: special case: in-memory pytorch model
if nn_module:

@ -325,9 +325,9 @@ def torch_safe_load(weight):
Returns:
The loaded PyTorch model.
"""
from ultralytics.yolo.utils.downloads import attempt_download
from ultralytics.yolo.utils.downloads import attempt_download_asset
file = attempt_download(weight) # search online if missing locally
file = attempt_download_asset(weight) # search online if missing locally
try:
return torch.load(file, map_location='cpu') # load
except ModuleNotFoundError as e:

@ -90,7 +90,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG, override
# Type checks
for k in 'project', 'name':
if isinstance(cfg[k], (int, float)):
if k in cfg and isinstance(cfg[k], (int, float)):
cfg[k] = str(cfg[k])
# Return instance
@ -176,7 +176,7 @@ def entrypoint(debug=False):
'version': lambda: LOGGER.info(__version__),
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
'copy-cfg': copy_default_config}
'copy-cfg': copy_default_cfg}
overrides = {} # basic overrides, i.e. imgsz=320
for a in merge_equals_args(args): # merge spaces around '=' sign
@ -221,7 +221,7 @@ def entrypoint(debug=False):
task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='mnist160')
# Mode
mode = overrides['mode']
mode = overrides.get('mode', None)
if mode is None:
mode = DEFAULT_CFG.mode or 'predict'
LOGGER.warning(f"WARNING ⚠️ 'mode=' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
@ -266,7 +266,7 @@ def entrypoint(debug=False):
# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_config():
def copy_default_cfg():
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
shutil.copy2(DEFAULT_CFG_PATH, new_file)
LOGGER.info(f"{PREFIX}{DEFAULT_CFG_PATH} copied to {new_file}\n"

@ -44,8 +44,20 @@ class Compose:
self.transforms = transforms
def __call__(self, data):
mosaic_p = None
mosaic_imgsz = None
for t in self.transforms:
if isinstance(t, Mosaic):
temp = t(data)
mosaic_p = False if temp == data else True
mosaic_imgsz = t.imgsz
data = temp
else:
if isinstance(t, RandomPerspective):
t.border = [-mosaic_imgsz // 2, -mosaic_imgsz // 2] if mosaic_p else [0, 0]
data = t(data)
return data
def append(self, transform):

@ -120,7 +120,8 @@ class BaseDataset(Dataset):
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
assert im is not None, f"Image Not Found {f}"
if im is None:
raise FileNotFoundError(f"Image Not Found {f}")
h0, w0 = im.shape[:2] # orig hw
r = self.imgsz / max(h0, w0) # ratio
if r != 1: # if sizes are not equal

@ -65,7 +65,7 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
assert mode in ["train", "val"]
shuffle = mode == "train"
if cfg.rect and shuffle:
LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = YOLODataset(

@ -64,7 +64,7 @@ download: |
# Download
dir = Path(yaml['path']) # dataset root dir
urls = ['https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip']
download(urls, dir=dir, delete=False)
download(urls, dir=dir)
# Convert
annotations_dir = 'Argoverse-HD/annotations/'

@ -411,12 +411,12 @@ download: |
# Download
url = f"https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/{split}/"
if split == 'train':
download([f'{url}zhiyuan_objv2_{split}.tar.gz'], dir=dir, delete=False) # annotations json
download([f'{url}patch{i}.tar.gz' for i in range(patches)], dir=images, curl=True, delete=False, threads=8)
download([f'{url}zhiyuan_objv2_{split}.tar.gz'], dir=dir) # annotations json
download([f'{url}patch{i}.tar.gz' for i in range(patches)], dir=images, curl=True, threads=8)
elif split == 'val':
download([f'{url}zhiyuan_objv2_{split}.json'], dir=dir, delete=False) # annotations json
download([f'{url}images/v1/patch{i}.tar.gz' for i in range(15 + 1)], dir=images, curl=True, delete=False, threads=8)
download([f'{url}images/v2/patch{i}.tar.gz' for i in range(16, patches)], dir=images, curl=True, delete=False, threads=8)
download([f'{url}zhiyuan_objv2_{split}.json'], dir=dir) # annotations json
download([f'{url}images/v1/patch{i}.tar.gz' for i in range(15 + 1)], dir=images, curl=True, threads=8)
download([f'{url}images/v2/patch{i}.tar.gz' for i in range(16, patches)], dir=images, curl=True, threads=8)
# Move
for f in tqdm(images.rglob('*.jpg'), desc=f'Moving {split} images'):

@ -34,7 +34,7 @@ download: |
dir = Path(yaml['path']) # dataset root dir
parent = Path(dir.parent) # download dir
urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz']
download(urls, dir=parent, delete=False)
download(urls, dir=parent)
# Rename directories
if dir.exists():

@ -81,7 +81,7 @@ download: |
urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images
f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images
f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images
download(urls, dir=dir / 'images', delete=False, curl=True, threads=3)
download(urls, dir=dir / 'images', curl=True, threads=3)
# Convert
path = dir / 'images/VOCdevkit'

@ -138,7 +138,7 @@ download: |
# urls = ['https://d307kc0mrhucc3.cloudfront.net/train_labels.zip', # train labels
# 'https://d307kc0mrhucc3.cloudfront.net/train_images.zip', # 15G, 847 train images
# 'https://d307kc0mrhucc3.cloudfront.net/val_images.zip'] # 5G, 282 val images (no labels)
# download(urls, dir=dir, delete=False)
# download(urls, dir=dir)
# Convert labels
convert_labels(dir / 'xView_train.geojson')

@ -237,11 +237,7 @@ def check_det_dataset(dataset, autodownload=True):
raise FileNotFoundError(msg)
t = time.time()
if s.startswith('http') and s.endswith('.zip'): # URL
f = Path(s).name # filename
safe_download(file=f, url=s)
Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
unzip_file(f, path=DATASETS_DIR) # unzip
Path(f).unlink() # remove zip
safe_download(url=s, dir=DATASETS_DIR, delete=True)
r = None # success
elif s.startswith('bash '): # bash script
LOGGER.info(f'Running {s} ...')
@ -251,7 +247,7 @@ def check_det_dataset(dataset, autodownload=True):
dt = f'({round(time.time() - t, 1)}s)'
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt}"
LOGGER.info(f"Dataset download {s}")
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
return data # dictionary

@ -125,7 +125,7 @@ class Exporter:
Initializes the Exporter class.
Args:
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_cfg(cfg, overrides)

@ -28,7 +28,6 @@ DEFAULT_CFG_PATH = ROOT / "yolo/cfg/default.yaml"
RANK = int(os.getenv('RANK', -1))
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format
LOGGING_NAME = 'ultralytics'
@ -328,6 +327,20 @@ def get_git_origin_url():
return None # if not git dir or on error
def get_git_branch():
"""
Returns the current git branch name. If not in a git repository, returns None.
Returns:
(str) or (None): The current git branch name.
"""
if is_git_dir():
with contextlib.suppress(subprocess.CalledProcessError):
origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
return origin.decode().strip()
return None # if not git dir or on error
def get_default_args(func):
# Get func() default arguments
signature = inspect.signature(func)
@ -466,7 +479,8 @@ def set_sentry():
if SETTINGS['sync'] and \
not is_pytest_running() and \
not is_github_actions_ci() and \
(is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git"):
(is_pip_package() or
(get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")):
import sentry_sdk # noqa
import ultralytics

@ -28,7 +28,7 @@ def autobatch(model, imgsz=640, fraction=0.7, batch_size=16):
# Check device
prefix = colorstr('AutoBatch: ')
LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
device = next(model.parameters()).device # get model device
if device.type == 'cpu':
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')

@ -17,9 +17,10 @@ import pkg_resources as pkg
import psutil
import torch
from IPython import display
from matplotlib import font_manager
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads,
emojis, is_colab, is_docker, is_jupyter)
from ultralytics.yolo.utils import (AUTOINSTALL, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads, emojis,
is_colab, is_docker, is_jupyter)
def is_ascii(s) -> bool:
@ -57,15 +58,14 @@ def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
# Convert image size to list if it is an integer
if isinstance(imgsz, int):
imgsz = [imgsz]
imgsz = [imgsz] if isinstance(imgsz, int) else list(imgsz)
# Make image size a multiple of the stride
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
# Print warning message if image size was updated
if sz != imgsz:
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {stride}, updating to {sz}')
LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}')
# Add missing dimensions if necessary
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
@ -104,26 +104,33 @@ def check_version(current: str = "0.0.0",
return result
def check_font(font: str = FONT, progress: bool = False) -> None:
def check_font(font='Arial.ttf'):
"""
Download font file to the user's configuration directory if it does not already exist.
Find font locally or download to user's configuration directory if it does not already exist.
Args:
font (str): Path to font file.
progress (bool): If True, display a progress bar during the download.
font (str): Path or name of font.
Returns:
None
file (Path): Resolved font file path.
"""
font = Path(font)
name = Path(font).name
# Check USER_CONFIG_DIR
file = USER_CONFIG_DIR / name
if file.exists():
return file
# Destination path for the font file
file = USER_CONFIG_DIR / font.name
# Check system fonts
matches = [s for s in font_manager.findSystemFonts() if font in s]
if any(matches):
return matches[0]
# Check if font file exists at the source or destination path
if not font.exists() and not file.exists():
# Download font file
downloads.safe_download(file=file, url=f'https://ultralytics.com/assets/{font.name}', progress=progress)
# Download to USER_CONFIG_DIR if missing
url = f'https://ultralytics.com/assets/{name}'
if downloads.is_url(url):
downloads.safe_download(url=url, file=file)
return file
def check_online() -> bool:
@ -213,7 +220,7 @@ def check_file(file, suffix=''):
if Path(file).is_file():
LOGGER.info(f'Found {url} locally at {file}') # file already exists
else:
downloads.safe_download(file=file, url=url)
downloads.safe_download(url=url, file=file)
return file
else: # search
files = []

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import logging
import contextlib
import os
import subprocess
import urllib
@ -15,27 +15,6 @@ import torch
from ultralytics.yolo.utils import LOGGER
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg='', progress=True):
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
file = Path(file)
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
try: # url1
LOGGER.info(f'Downloading {url} to {file}...')
torch.hub.download_url_to_file(url, str(file), progress=progress and LOGGER.level <= logging.INFO)
assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
except Exception as e: # url2
if file.exists():
file.unlink() # remove partial downloads
LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
finally:
if not file.exists() or file.stat().st_size < min_bytes: # check
if file.exists():
file.unlink() # remove partial downloads
LOGGER.warning(f"ERROR: {assert_msg}\n{error_msg}")
LOGGER.info('')
def is_url(url, check=True):
# Check if string is URL and check if URL exists
try:
@ -47,7 +26,71 @@ def is_url(url, check=True):
return False
def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
def safe_download(url,
file=None,
dir=None,
unzip=True,
delete=False,
curl=False,
retry=3,
min_bytes=1E0,
progress=True):
"""
Function for downloading files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
Args:
url: str: The URL of the file to be downloaded.
file: str, optional: The filename of the downloaded file.
If not provided, the file will be saved with the same name as the URL.
dir: str, optional: The directory to save the downloaded file.
If not provided, the file will be saved in the current working directory.
unzip: bool, optional: Whether to unzip the downloaded file. Default: True.
delete: bool, optional: Whether to delete the downloaded file after unzipping. Default: False.
curl: bool, optional: Whether to use curl command line tool for downloading. Default: False.
retry: int, optional: The number of times to retry the download in case of failure. Default: 3.
min_bytes: float, optional: The minimum number of bytes that the downloaded file should have, to be considered
a successful download. Default: 1E0.
progress: bool, optional: Whether to display a progress bar during the download. Default: True.
"""
if '://' not in str(url) and Path(url).is_file(): # exists ('://' check required in Windows Python<3.10)
f = Path(url) # filename
else: # does not exist
assert dir or file, 'dir or file required for download'
f = dir / Path(url).name if dir else Path(file)
LOGGER.info(f'Downloading {url} to {f}...')
f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
for i in range(retry + 1):
try:
if curl or i > 0: # curl download with retry, continue
s = 'sS' * (not progress) # silent
r = os.system(f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')
else: # torch download
r = torch.hub.download_url_to_file(url, f, progress=progress)
assert r in {0, None}
except Exception as e:
if i >= retry:
raise ConnectionError(f'❌ Download failure for {url}') from e
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
continue
if f.exists():
if f.stat().st_size > min_bytes:
break # success
f.unlink() # remove partial downloads
if unzip and f.exists() and f.suffix in {'.zip', '.tar', '.gz'}:
LOGGER.info(f'Unzipping {f}...')
if f.suffix == '.zip':
ZipFile(f).extractall(path=f.parent) # unzip
elif f.suffix == '.tar':
os.system(f'tar xf {f} --directory {f.parent}') # unzip
elif f.suffix == '.gz':
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
if delete:
f.unlink() # remove zip
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
from ultralytics.yolo.utils import SETTINGS
@ -73,7 +116,7 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
if Path(file).is_file():
LOGGER.info(f'Found {url} locally at {file}') # file already exists
else:
safe_download(file=file, url=url, min_bytes=1E5)
safe_download(url=url, file=file, min_bytes=1E5)
return file
# GitHub assets
@ -91,61 +134,23 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
if name in assets:
safe_download(file,
url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
min_bytes=1E5,
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag}')
safe_download(url=f'https://github.com/{repo}/releases/download/{tag}/{name}', file=file, min_bytes=1E5)
return str(file)
def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1, retry=3):
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3):
# Multithreaded file download and unzip function, used in data.yaml for autodownload
def download_one(url, dir):
# Download 1 file
success = True
if '://' not in str(url) and Path(url).is_file(): # exists ('://' check required in Windows Python<3.10)
f = Path(url) # filename
else: # does not exist
f = dir / Path(url).name
LOGGER.info(f'Downloading {url} to {f}...')
for i in range(retry + 1):
if curl: # curl download with retry, continue
s = 'sS' * (threads > 1) # silent
r = os.system(f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')
success = r == 0
else: # torch download
torch.hub.download_url_to_file(url, f, progress=threads == 1)
success = f.is_file()
if success:
break
elif i < retry:
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
else:
LOGGER.warning(f'❌ Failed to download {url}...')
if unzip and success and f.suffix in ('.zip', '.tar', '.gz'):
LOGGER.info(f'Unzipping {f}...')
if f.suffix == '.zip':
ZipFile(f).extractall(path=dir) # unzip
elif f.suffix == '.tar':
os.system(f'tar xf {f} --directory {f.parent}') # unzip
elif f.suffix == '.gz':
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
if delete:
f.unlink() # remove zip
dir = Path(dir)
dir.mkdir(parents=True, exist_ok=True) # make directory
if threads > 1:
# pool = ThreadPool(threads)
# pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
# pool.close()
# pool.join()
with ThreadPool(threads) as pool:
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
pool.map(
lambda x: safe_download(
url=x[0], dir=x[1], unzip=unzip, delete=delete, curl=curl, retry=retry, progress=threads <= 1),
zip(url, repeat(dir)))
pool.close()
pool.join()
else:
for u in [url] if isinstance(url, (str, Path)) else url:
download_one(u, dir)
safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry)

@ -3,7 +3,6 @@
import contextlib
import math
from pathlib import Path
from urllib.error import URLError
import cv2
import matplotlib.pyplot as plt
@ -12,9 +11,9 @@ import pandas as pd
import torch
from PIL import Image, ImageDraw, ImageFont
from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR, threaded
from ultralytics.yolo.utils import threaded
from .checks import check_font, check_requirements, is_ascii
from .checks import check_font, is_ascii
from .files import increment_path
from .ops import clip_coords, scale_image, xywh2xyxy, xyxy2xywh
@ -49,14 +48,20 @@ class Annotator:
if self.pil: # use PIL
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
self.draw = ImageDraw.Draw(self.im)
self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
try:
font = check_font('Arial.Unicode.ttf' if non_ascii else font)
size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
self.font = ImageFont.truetype(str(font), size)
except Exception:
self.font = ImageFont.load_default()
else: # use cv2
self.im = im
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
# Add one xyxy box to image with label
if isinstance(box, torch.Tensor):
box = box.tolist()
if self.pil or not is_ascii(label):
self.draw.rectangle(box, width=self.lw, outline=color) # box
if label:
@ -139,22 +144,6 @@ class Annotator:
return np.asarray(self.im)
def check_pil_font(font=FONT, size=10):
# Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
font = Path(font)
font = font if font.exists() else (USER_CONFIG_DIR / font.name)
try:
return ImageFont.truetype(str(font) if font.exists() else font.name, size)
except Exception: # download if missing
try:
check_font(font)
return ImageFont.truetype(str(font), size)
except TypeError:
check_requirements('Pillow>=8.4.0') # known issue https://github.com/ultralytics/yolov5/issues/5374
except URLError: # not online
return ImageFont.load_default()
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
xyxy = torch.tensor(xyxy).view(-1, 4)

@ -85,8 +85,8 @@ def select_device(device='', batch=0, newline=False):
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
n = len(devices) # device count
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
raise ValueError(f'batch={batch} is not multiple of GPU count {n}.\n'
f'Try batch={batch // n} or batch={batch // n + 1}')
raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.")
space = ' ' * (len(s) + 1)
for i, d in enumerate(devices):
p = torch.cuda.get_device_properties(i)

@ -74,7 +74,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = ClassificationPredictor(args)
predictor = ClassificationPredictor(overrides=args)
predictor.predict_cli()

@ -146,7 +146,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = ClassificationTrainer(args)
trainer = ClassificationTrainer(overrides=args)
trainer.train()

@ -92,7 +92,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = DetectionPredictor(args)
predictor = DetectionPredictor(overrides=args)
predictor.predict_cli()

@ -204,7 +204,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = DetectionTrainer(args)
trainer = DetectionTrainer(overrides=args)
trainer.train()

@ -110,7 +110,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = SegmentationPredictor(args)
predictor = SegmentationPredictor(overrides=args)
predictor.predict_cli()

@ -150,7 +150,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = SegmentationTrainer(args)
trainer = SegmentationTrainer(overrides=args)
trainer.train()

Loading…
Cancel
Save