You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
354 lines
14 KiB
354 lines
14 KiB
2 years ago
|
# TODO: Use Google-style docstrings for all functions so automatic documentation parsers can read them.
|
||
|
|
||
|
import contextlib
|
||
|
import logging
|
||
|
import os
|
||
|
import platform
|
||
|
import subprocess
|
||
|
import urllib
|
||
|
from itertools import repeat
|
||
|
from multiprocessing.pool import ThreadPool
|
||
|
from pathlib import Path
|
||
|
from zipfile import ZipFile
|
||
|
|
||
|
import numpy as np
|
||
|
import pkg_resources as pkg
|
||
|
import requests
|
||
|
import torch
|
||
|
import yaml
|
||
|
|
||
|
FILE = Path(__file__).resolve()
ROOT = FILE.parents[2]  # YOLOv5 root directory
RANK = int(os.getenv('RANK', -1))

# Settings
DATASETS_DIR = ROOT.parent / 'datasets'  # YOLOv5 datasets directory
# os.cpu_count() returns None when the CPU count cannot be determined; treat that as 1 CPU
NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))  # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global verbose mode
FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf
|
||
|
|
||
|
|
||
|
def is_colab():
    """Return True if the process appears to run in Google Colab (the ``COLAB_GPU`` env var is set)."""
    colab_marker = "COLAB_GPU"
    return colab_marker in os.environ.keys()
|
||
|
|
||
|
|
||
|
def is_kaggle():
    """Return True if the process appears to run in a Kaggle Notebook.

    Both ``PWD`` must equal ``/kaggle/working`` and ``KAGGLE_URL_BASE`` must equal
    ``https://www.kaggle.com``.
    """
    required_env = {"PWD": "/kaggle/working", "KAGGLE_URL_BASE": "https://www.kaggle.com"}
    return all(os.environ.get(key) == value for key, value in required_env.items())
|
||
|
|
||
|
|
||
|
def emojis(str=""):
    """Return an emoji-safe version of ``str`` for the current platform.

    On Windows every non-ASCII character is dropped; on other platforms the string is
    returned unchanged.
    """
    if platform.system() != "Windows":
        return str
    return str.encode().decode("ascii", "ignore")
|
||
|
|
||
|
|
||
|
def set_logging(name=None, verbose=VERBOSE):
    """Configure logger ``name`` to print bare messages to a stream handler and return it.

    The level is INFO when ``verbose`` is true and the ``RANK`` environment variable (rank in
    world for multi-GPU training) is -1 or 0, otherwise ERROR. On Kaggle or Colab every handler
    on the root logger is removed first. Calling this again for the same logger reconfigures the
    handler installed by a previous call instead of adding another one, so messages are not
    printed twice.

    Args:
        name: Logger name; None configures the root logger.
        verbose: Whether to log at INFO level (subject to the rank check above).

    Returns:
        logging.Logger: The configured logger.
    """
    if is_kaggle() or is_colab():
        # Iterate over a copy: removeHandler() mutates logging.root.handlers in place, so looping
        # over the live list would skip every other handler.
        for h in list(logging.root.handlers):
            logging.root.removeHandler(h)  # remove all handlers associated with the root logger object
    rank = int(os.getenv("RANK", -1))  # rank in world for Multi-GPU trainings
    level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
    log = logging.getLogger(name)
    log.setLevel(level)
    handler_name = "yolov5_set_logging"  # tags the handler this function installs
    handler = next((h for h in log.handlers if h.get_name() == handler_name), None)
    if handler is None:
        handler = logging.StreamHandler()
        handler.set_name(handler_name)
        handler.setFormatter(logging.Formatter("%(message)s"))
        log.addHandler(handler)
    handler.setLevel(level)
    return log
|
||
|
|
||
|
|
||
|
set_logging()  # run before defining LOGGER
LOGGER = logging.getLogger("yolov5")  # define globally (used in train.py, val.py, detect.py, etc.)
if platform.system() == "Windows":
    for fn in LOGGER.info, LOGGER.warning:
        # Emoji-safe logging. fn is bound as a default argument: a plain closure would look fn up
        # at call time, so both patched methods would forward to the last loop value (warning).
        # *args/**kwargs keep %-style arguments and keywords such as exc_info working.
        setattr(LOGGER, fn.__name__, lambda x, *args, fn=fn, **kwargs: fn(emojis(x), *args, **kwargs))
|
||
|
|
||
|
|
||
|
def segment2box(segment, width=640, height=640):
    """Convert one segment label to one box label, keeping only points inside the image.

    Converts (xy1, xy2, ...) to (xyxy).

    Args:
        segment: Polygon points, unpacked as ``x, y = segment.T`` (so two columns: x and y).
        width: Image width; points with x outside [0, width] are dropped.
        height: Image height; points with y outside [0, height] are dropped.

    Returns:
        np.ndarray: [x_min, y_min, x_max, y_max] of the inside points, or ``np.zeros(4)`` if no
        point lies inside the image.
    """
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    # Test emptiness with .size: any(x) would also be False when all inside points have x == 0
    return np.array([x.min(), y.min(), x.max(), y.max()]) if x.size else np.zeros(4)  # xyxy
|
||
|
|
||
|
|
||
|
def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False):
    """Check a current version against a required version.

    Args:
        current: Installed version string.
        minimum: Required version string.
        name: Prefix used in the warning message, e.g. "torch ".
        pinned: If True, require ``current == minimum``; otherwise ``current >= minimum``.
        hard: If True, raise when the requirement is not met.
        verbose: If True, log a warning when the requirement is not met.

    Returns:
        bool: Whether the requirement is met.

    Raises:
        AssertionError: If ``hard`` is True and the requirement is not met.
    """
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)  # bool
    s = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed"  # string
    if hard and not result:
        # Explicit raise rather than `assert`, so the check still runs under `python -O`;
        # AssertionError is kept so callers catching it behave as before.
        raise AssertionError(emojis(s))  # assert min requirements met
    if verbose and not result:
        LOGGER.warning(s)
    return result
|
||
|
|
||
|
|
||
|
def colorstr(*input):
    """Wrap a value in ANSI escape codes, e.g. ``colorstr('blue', 'hello world')``.

    Every positional argument except the last is a color/style name, and the last one is the
    value to color. With a single argument, the value is rendered blue and bold.
    See https://en.wikipedia.org/wiki/ANSI_escape_code.
    """
    if len(input) > 1:
        *styles, text = input
    else:
        styles, text = ["blue", "bold"], input[0]
    codes = dict(
        # basic colors
        black="\033[30m",
        red="\033[31m",
        green="\033[32m",
        yellow="\033[33m",
        blue="\033[34m",
        magenta="\033[35m",
        cyan="\033[36m",
        white="\033[37m",
        # bright colors
        bright_black="\033[90m",
        bright_red="\033[91m",
        bright_green="\033[92m",
        bright_yellow="\033[93m",
        bright_blue="\033[94m",
        bright_magenta="\033[95m",
        bright_cyan="\033[96m",
        bright_white="\033[97m",
        # misc
        end="\033[0m",
        bold="\033[1m",
        underline="\033[4m",
    )
    prefix = "".join(codes[style] for style in styles)
    return f"{prefix}{text}{codes['end']}"
|
||
|
|
||
|
|
||
|
def xyxy2xywh(x):
    """Convert boxes from [x1, y1, x2, y2] to [x, y, w, h], where xy1=top-left, xy2=bottom-right.

    Args:
        x: np.ndarray or torch.Tensor whose last dimension holds the 4 box coordinates, e.g.
            shape (n, 4). Any leading dimensions, including none for a single box, are supported.

    Returns:
        A new array/tensor of the same type and shape; the input is not modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y
|
||
|
|
||
|
|
||
|
def xywh2xyxy(x):
    """Convert boxes from [x, y, w, h] to [x1, y1, x2, y2], where xy1=top-left, xy2=bottom-right.

    Args:
        x: np.ndarray or torch.Tensor whose last dimension holds the 4 box coordinates, e.g.
            shape (n, 4). Any leading dimensions, including none for a single box, are supported.

    Returns:
        A new array/tensor of the same type and shape; the input is not modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y
|
||
|
|
||
|
|
||
|
def xywh2ltwh(x):
    """Convert boxes from [x, y, w, h] (center-based) to [x1, y1, w, h], where xy1=top-left.

    Args:
        x: np.ndarray or torch.Tensor whose last dimension holds the 4 box coordinates, e.g.
            shape (n, 4). Any leading dimensions, including none for a single box, are supported.

    Returns:
        A new array/tensor of the same type and shape; width and height are copied unchanged.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    return y
|
||
|
|
||
|
|
||
|
def xyxy2ltwh(x):
    """Convert boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.

    Args:
        x: np.ndarray or torch.Tensor whose last dimension holds the 4 box coordinates, e.g.
            shape (n, 4). Any leading dimensions, including none for a single box, are supported.

    Returns:
        A new array/tensor of the same type and shape; the top-left corner is copied unchanged.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y
|
||
|
|
||
|
|
||
|
def ltwh2xywh(x):
    """Convert boxes from [x1, y1, w, h] to [x, y, w, h], where xy1=top-left, xy=center.

    Args:
        x: np.ndarray or torch.Tensor whose last dimension holds the 4 box coordinates, e.g.
            shape (n, 4). Any leading dimensions, including none for a single box, are supported.

    Returns:
        A new array/tensor of the same type and shape; width and height are copied unchanged.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] + x[..., 2] / 2  # center x
    y[..., 1] = x[..., 1] + x[..., 3] / 2  # center y
    return y
|
||
|
|
||
|
|
||
|
def ltwh2xyxy(x):
    """Convert boxes from [x1, y1, w, h] to [x1, y1, x2, y2], where xy1=top-left, xy2=bottom-right.

    Args:
        x: np.ndarray or torch.Tensor whose last dimension holds the 4 box coordinates, e.g.
            shape (n, 4). Any leading dimensions, including none for a single box, are supported.

    Returns:
        A new array/tensor of the same type and shape; the top-left corner is copied unchanged.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 2] = x[..., 2] + x[..., 0]  # bottom right x = width + x1
    y[..., 3] = x[..., 3] + x[..., 1]  # bottom right y = height + y1
    return y
|
||
|
|
||
|
|
||
|
def segments2boxes(segments):
    """Convert segment labels to box labels, i.e. (xy1, xy2, ...) to (xywh).

    Args:
        segments: Iterable of point arrays, each unpacked as ``x, y = s.T`` (two columns: x, y).

    Returns:
        np.ndarray: Shape (len(segments), 4), one xywh box per segment; (0, 4) for empty input.
    """
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # xyxy
    # reshape keeps a 2-D (0, 4) array for empty input, which xyxy2xywh can index column-wise
    return xyxy2xywh(np.array(boxes).reshape(-1, 4))  # xywh
|
||
|
|
||
|
|
||
|
def resample_segments(segments, n=1000):
    """Resample every segment in ``segments`` to ``n`` points, in place.

    Each segment is first closed by appending its first point. Both coordinate columns are then
    linearly interpolated at ``n`` evenly spaced positions along the point index. Each entry of
    the list is replaced with an (n, 2) array, and the same list object is returned.
    """
    for idx, seg in enumerate(segments):
        closed = np.concatenate((seg, seg[0:1, :]), axis=0)
        positions = np.linspace(0, len(closed) - 1, n)
        knots = np.arange(len(closed))
        columns = [np.interp(positions, knots, closed[:, dim]) for dim in range(2)]
        segments[idx] = np.array(columns).T  # segment xy
    return segments
|
||
|
|
||
|
|
||
|
def increment_path(path, exist_ok=False, sep='', mkdir=False):
    """Return a non-existing variant of ``path`` by appending a number, e.g. runs/exp -> runs/exp{sep}2.

    Args:
        path: File or directory path (str or Path).
        exist_ok: If True, return ``path`` unchanged even when it already exists.
        sep: Separator inserted between the original name and the number.
        mkdir: If True, create the resulting path as a directory (including parents).

    Returns:
        Path: ``path`` itself if it does not exist (or ``exist_ok`` is True). Otherwise, the first
        of ``{stem}{sep}2{suffix}`` ... ``{stem}{sep}9998{suffix}`` that does not exist, or the
        last candidate if all of them exist. For an existing file the number goes before the
        suffix, e.g. a.txt -> a2.txt.
    """
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        if path.is_file():
            stem, suffix = path.with_suffix(''), path.suffix
        else:
            stem, suffix = path, ''
        candidates = (f'{stem}{sep}{n}{suffix}' for n in range(2, 9999))
        last_candidate = f'{stem}{sep}{9998}{suffix}'
        path = Path(next((c for c in candidates if not os.path.exists(c)), last_candidate))
    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory
    return path
|
||
|
|
||
|
|
||
|
def save_yaml(file='data.yaml', data=None):
    """Save ``data`` to ``file`` with ``yaml.safe_dump``, writing Path values as strings.

    Args:
        file: Destination file path.
        data: Mapping to save; None (the default) writes an empty mapping. Key order is kept
            (``sort_keys=False``).
    """
    data = {} if data is None else data  # avoid a shared mutable default argument
    with open(file, 'w') as f:
        yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
||
|
|
||
|
|
||
|
def download(url, dir=None, unzip=True, delete=True, curl=False, threads=1, retry=3):
    """Download one or more files (optionally multithreaded) and unpack archives; used in data.yaml for autodownload.

    Args:
        url: A single URL or local path (str or Path), or an iterable of them. Entries that already
            exist as local files are not downloaded, only (optionally) unpacked.
        dir: Destination directory, created if missing. Defaults to the current working
            directory at call time.
        unzip: Unpack '.zip', '.tar' and '.gz' files after a successful download.
        delete: Delete the archive after unpacking it.
        curl: Download with the ``curl`` command instead of ``torch.hub.download_url_to_file``.
        threads: Number of concurrent downloads; values > 1 use a ThreadPool.
        retry: Number of additional attempts after a failed download.
    """

    def run(cmd):
        # Run an external command without a shell; False if it fails or is not installed.
        try:
            return subprocess.run(cmd).returncode == 0
        except OSError:
            return False

    def download_one(url, dir):
        # Download 1 file (or reuse an existing local file), then optionally unpack it
        success = True
        if Path(url).is_file():
            f = Path(url)  # filename
        else:  # does not exist
            f = dir / Path(url).name
            LOGGER.info(f'Downloading {url} to {f}...')
            for i in range(retry + 1):
                if curl:
                    s = 'sS' if threads > 1 else ''  # silent
                    # curl download with retry, continue; argument list avoids shell quoting issues
                    success = run(['curl', '-#', f'-{s}L', str(url), '-o', str(f), '--retry', '9', '-C', '-'])
                else:
                    torch.hub.download_url_to_file(url, f, progress=threads == 1)  # torch download
                    success = f.is_file()
                if success:
                    break
                elif i < retry:
                    LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
                else:
                    LOGGER.warning(f'❌ Failed to download {url}...')

        if unzip and success and f.suffix in ('.zip', '.tar', '.gz'):
            LOGGER.info(f'Unzipping {f}...')
            if f.suffix == '.zip':
                with ZipFile(f) as zf:  # close the archive before it may be deleted below
                    zf.extractall(path=dir)  # unzip
            elif f.suffix == '.tar':
                run(['tar', 'xf', str(f), '--directory', str(f.parent)])  # unzip
            elif f.suffix == '.gz':
                run(['tar', 'xfz', str(f), '--directory', str(f.parent)])  # unzip
            if delete:
                f.unlink()  # remove zip

    dir = Path.cwd() if dir is None else Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    # Normalize to a sequence of URLs so a single URL is never iterated character by character
    urls = [url] if isinstance(url, (str, Path)) else url
    if threads > 1:
        pool = ThreadPool(threads)
        # imap results are never consumed, so worker exceptions are not re-raised here
        pool.imap(lambda x: download_one(*x), zip(urls, repeat(dir)))  # multithreaded
        pool.close()
        pool.join()
    else:
        for u in urls:
            download_one(u, dir)
|
||
|
|
||
|
|
||
|
class WorkingDirectory(contextlib.ContextDecorator):
    """Temporarily change the working directory.

    Usage: ``@WorkingDirectory(dir)`` decorator or ``with WorkingDirectory(dir):`` context manager.
    On exit, the working directory that was current when the block was entered is restored.
    """

    def __init__(self, new_dir):
        self.dir = new_dir  # new dir
        self.cwd = Path.cwd().resolve()  # current dir (refreshed on every __enter__)

    def __enter__(self):
        # Record the directory to restore at entry, not construction: as a decorator the instance
        # is created once at definition time but may be entered later from a different cwd.
        self.cwd = Path.cwd().resolve()
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.chdir(self.cwd)
|
||
|
|
||
|
|
||
|
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
    """Download ``url`` to ``file``, falling back to curl on ``url2`` (or ``url``) on failure.

    Downloads that do not exist or are smaller than ``min_bytes`` afterwards are deleted, and an
    error is logged.

    Args:
        file: Destination path.
        url: Primary URL, fetched with ``torch.hub.download_url_to_file``.
        url2: Optional fallback URL, fetched with curl; defaults to ``url``.
        min_bytes: Minimum acceptable file size in bytes.
        error_msg: Extra text appended to the logged error when the download fails.
    """
    from utils.general import LOGGER

    file = Path(file)
    assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
    try:  # url1
        LOGGER.info(f'Downloading {url} to {file}...')
        torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
        # Explicit raise instead of `assert` so the check (and the url2 fallback) survives `python -O`
        if not (file.exists() and file.stat().st_size > min_bytes):
            raise AssertionError(assert_msg)  # check
    except Exception as e:  # url2
        if file.exists():
            file.unlink()  # remove partial downloads
        LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
        # curl download, retry and resume on fail; argument list avoids shell quoting/injection
        try:
            subprocess.run(['curl', '-#', '-L', str(url2 or url), '-o', str(file), '--retry', '3', '-C', '-'])
        except OSError:
            pass  # curl unavailable: the size check in `finally` reports the failure
    finally:
        if not file.exists() or file.stat().st_size < min_bytes:  # check
            if file.exists():
                file.unlink()  # remove partial downloads
            LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}")
        LOGGER.info('')
|
||
|
|
||
|
|
||
|
def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
    """Download ``file`` from a URL or from GitHub release assets if it is not found locally.

    Args:
        file: Local path, or an http(s) URL to download into the current directory.
        repo: GitHub repository ("owner/name") whose release assets are searched.
        release: Release tag to look in, e.g. 'v6.2' or 'latest'.

    Returns:
        The local file path as a str. For a URL argument this is the bare file name.
    """
    from urllib.parse import unquote  # `import urllib` alone does not guarantee urllib.parse is loaded

    from utils.general import LOGGER

    def github_assets(repository, version='latest'):
        # Return GitHub repo tag and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
        if version != 'latest':
            version = f'tags/{version}'  # i.e. tags/v6.2
        # timeout so an unreachable API falls through to the fallbacks below instead of hanging
        response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}', timeout=10).json()
        return response['tag_name'], [x['name'] for x in response['assets']]  # tag, assets

    file = Path(str(file).strip().replace("'", ''))
    if not file.exists():
        # URL specified
        name = Path(unquote(str(file))).name  # decode '%2F' to '/' etc.
        if str(file).startswith(('http:/', 'https:/')):  # download
            url = str(file).replace(':/', '://')  # Pathlib turns :// -> :/
            file = name.split('?')[0]  # parse authentication https://url.com/file.txt?auth...
            if Path(file).is_file():
                LOGGER.info(f'Found {url} locally at {file}')  # file already exists
            else:
                safe_download(file=file, url=url, min_bytes=1E5)
            return file

        # GitHub assets
        assets = [f'yolov5{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')]  # default
        try:
            tag, assets = github_assets(repo, release)
        except Exception:
            try:
                tag, assets = github_assets(repo)  # latest release
            except Exception:
                try:
                    tag = subprocess.check_output(['git', 'tag'], stderr=subprocess.STDOUT).decode().split()[-1]
                except Exception:
                    tag = release

        file.parent.mkdir(parents=True, exist_ok=True)  # make parent dir (if required)
        if name in assets:
            url3 = 'https://drive.google.com/drive/folders/1EFQTEUeXWSFww0luse2jB9M1QNZQGwNl'  # backup gdrive mirror
            safe_download(
                file,
                url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
                min_bytes=1E5,
                error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}')

    return str(file)
|
||
|
|
||
|
|
||
|
def get_model(model: str):
    """Placeholder for model loading; intended to check for local weights.

    Not implemented yet: ``model`` is ignored and None is returned.
    """
    return None
|