Trainer + Dataloaders (#27)
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ayush Chaurasia <ayushchaurasia@Ayushs-MacBook-Pro.local>
Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com>
@@ -0,0 +1,17 @@
from .general import WorkingDirectory, check_version, download, increment_path, save_yaml
from .torch_utils import LOCAL_RANK, RANK, WORLD_SIZE, DDP_model, select_device, torch_distributed_zero_first

__all__ = [
    # general
    "increment_path",
    "save_yaml",
    "WorkingDirectory",
    "download",
    "check_version",
    # torch
    "torch_distributed_zero_first",
    "LOCAL_RANK",
    "RANK",
    "WORLD_SIZE",
    "DDP_model",
    "select_device"]
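
Usage sketch (not part of this diff): downstream trainer code can pull everything from the package root:

from ultralytics.yolo.utils import (DDP_model, RANK, WORLD_SIZE, select_device,
                                    torch_distributed_zero_first)

device = select_device('')  # CUDA if available, else CPU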
0  ultralytics/yolo/utils/configs/__init__.py  Normal file
53  ultralytics/yolo/utils/configs/defaults.yaml  Normal file
@@ -0,0 +1,53 @@
train:
  epochs: 300
  batch_size: 16
  img_size: 640
  nosave: False
  cache: False # True/ram for RAM caching, or disk
  device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu
  workers: 8
  project: "ultralytics-yolo"
  name: "exp" # TODO: make this informative, maybe exp{#number}_{datetime} ?
  exist_ok: False
  pretrained: False
  optimizer: "Adam" # choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
  verbose: False
  seed: 0
  local_rank: -1

hyps:
  lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
  lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
  momentum: 0.937 # SGD momentum/Adam beta1
  weight_decay: 0.0005 # optimizer weight decay 5e-4
  warmup_epochs: 3.0 # warmup epochs (fractions ok)
  warmup_momentum: 0.8 # warmup initial momentum
  warmup_bias_lr: 0.1 # warmup initial bias lr
  box: 0.05 # box loss gain
  cls: 0.5 # cls loss gain
  cls_pw: 1.0 # cls BCELoss positive_weight
  obj: 1.0 # obj loss gain (scale with pixels)
  obj_pw: 1.0 # obj BCELoss positive_weight
  iou_t: 0.20 # IoU training threshold
  anchor_t: 4.0 # anchor-multiple threshold
  # anchors: 3 # anchors per output layer (0 to ignore)
  fl_gamma: 0.0 # focal loss gamma (EfficientDet default gamma=1.5)
  hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
  hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
  hsv_v: 0.4 # image HSV-Value augmentation (fraction)
  degrees: 0.0 # image rotation (+/- deg)
  translate: 0.1 # image translation (+/- fraction)
  scale: 0.5 # image scale (+/- gain)
  shear: 0.0 # image shear (+/- deg)
  perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
  flipud: 0.0 # image flip up-down (probability)
  fliplr: 0.5 # image flip left-right (probability)
  mosaic: 1.0 # image mosaic (probability)
  mixup: 0.0 # image mixup (probability)
  copy_paste: 0.0 # segment copy-paste (probability)

# to disable Hydra directory creation
hydra:
  output_subdir: null
  run:
    dir: .
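
A minimal sketch of reading these defaults without Hydra (the nested key access is an assumption about how the trainer consumes the file):

import yaml

with open("ultralytics/yolo/utils/configs/defaults.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["train"]["epochs"])  # 300
print(cfg["hyps"]["lr0"])      # 0.001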
353  ultralytics/yolo/utils/general.py  Normal file
@@ -0,0 +1,353 @@
# TODO: Follow google docs format for all functions. Easier for automatic doc parser

import contextlib
import logging
import os
import platform
import subprocess
import urllib
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from zipfile import ZipFile

import numpy as np
import pkg_resources as pkg
import requests
import torch
import yaml

FILE = Path(__file__).resolve()
ROOT = FILE.parents[2]  # YOLOv5 root directory
RANK = int(os.getenv('RANK', -1))

# Settings
DATASETS_DIR = ROOT.parent / 'datasets'  # YOLOv5 datasets directory
NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global verbose mode
FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf

def is_colab():
    # Is environment a Google Colab instance?
    return "COLAB_GPU" in os.environ


def is_kaggle():
    # Is environment a Kaggle Notebook?
    return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"


def emojis(str=""):
    # Return platform-dependent emoji-safe version of string
    return str.encode().decode("ascii", "ignore") if platform.system() == "Windows" else str


def set_logging(name=None, verbose=VERBOSE):
    # Sets logging level and attaches a stream handler
    if is_kaggle() or is_colab():
        for h in logging.root.handlers:
            logging.root.removeHandler(h)  # remove all handlers associated with the root logger object
    rank = int(os.getenv("RANK", -1))  # rank in world for Multi-GPU trainings
    level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
    log = logging.getLogger(name)
    log.setLevel(level)
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(message)s"))
    handler.setLevel(level)
    log.addHandler(handler)


set_logging()  # run before defining LOGGER
LOGGER = logging.getLogger("yolov5")  # define globally (used in train.py, val.py, detect.py, etc.)
if platform.system() == "Windows":
    for fn in LOGGER.info, LOGGER.warning:
        setattr(LOGGER, fn.__name__, lambda x, fn=fn: fn(emojis(x)))  # emoji-safe logging; bind fn per iteration

def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros(4)  # xyxy


def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False):
    # Check current version vs. required version
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)  # bool
    s = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed"  # string
    if hard:
        assert result, emojis(s)  # assert min requirements met
    if verbose and not result:
        LOGGER.warning(s)
    return result
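
Usage sketch for check_version (the reported torch version is whatever the local install provides):

import torch

ok = check_version(torch.__version__, minimum="1.7.0", name="torch ", verbose=True)  # soft check, returns bool
check_version(torch.__version__, minimum="1.7.0", name="torch ", hard=True)  # hard check, AssertionError on failure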
def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ("blue", "bold", input[0])  # color arguments, string
    colors = {
        "black": "\033[30m",  # basic colors
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "magenta": "\033[35m",
        "cyan": "\033[36m",
        "white": "\033[37m",
        "bright_black": "\033[90m",  # bright colors
        "bright_red": "\033[91m",
        "bright_green": "\033[92m",
        "bright_yellow": "\033[93m",
        "bright_blue": "\033[94m",
        "bright_magenta": "\033[95m",
        "bright_cyan": "\033[96m",
        "bright_white": "\033[97m",
        "end": "\033[0m",  # misc
        "bold": "\033[1m",
        "underline": "\033[4m"}
    return "".join(colors[x] for x in args) + f"{string}" + colors["end"]

def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y


def xywh2ltwh(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, w, h] where xy1=top-left
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    return y


def xyxy2ltwh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x1, y1, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def ltwh2xywh(x):
    # Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] + x[:, 2] / 2  # center x
    y[:, 1] = x[:, 1] + x[:, 3] / 2  # center y
    return y


def ltwh2xyxy(x):
    # Convert nx4 boxes from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 2] = x[:, 2] + x[:, 0]  # bottom right x
    y[:, 3] = x[:, 3] + x[:, 1]  # bottom right y
    return y
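
The xywh/xyxy converters are mutual inverses; a quick sketch:

import numpy as np

boxes_xywh = np.array([[50., 50., 20., 10.]])  # one box: center (50, 50), 20x10
boxes_xyxy = xywh2xyxy(boxes_xywh)             # [[40., 45., 60., 55.]]
assert np.allclose(xyxy2xywh(boxes_xyxy), boxes_xywh)  # round-trip is lossless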
def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment
    for i, s in enumerate(segments):
        s = np.concatenate((s, s[0:1, :]), axis=0)
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
    return segments
def increment_path(path, exist_ok=False, sep='', mkdir=False):
    """
    Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    # TODO: docs
    """
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')

        # Method 1
        for n in range(2, 9999):
            p = f'{path}{sep}{n}{suffix}'  # increment path
            if not os.path.exists(p):
                break
        path = Path(p)

    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory

    return path

def save_yaml(file='data.yaml', data={}):
    # Single-line safe yaml saving
    with open(file, 'w') as f:
        yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
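
A usage sketch (paths are illustrative): repeated calls yield incrementing run directories, and save_yaml persists a plain dict:

p1 = increment_path('runs/exp', mkdir=True)  # runs/exp
p2 = increment_path('runs/exp', mkdir=True)  # runs/exp2
save_yaml(p2 / 'args.yaml', {'epochs': 300, 'img_size': 640})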

def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1, retry=3):
    # Multithreaded file download and unzip function, used in data.yaml for autodownload
    def download_one(url, dir):
        # Download 1 file
        success = True
        if Path(url).is_file():
            f = Path(url)  # filename
        else:  # does not exist
            f = dir / Path(url).name
            LOGGER.info(f'Downloading {url} to {f}...')
            for i in range(retry + 1):
                if curl:
                    s = 'sS' if threads > 1 else ''  # silent
                    r = os.system(f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')  # curl download with retry, continue
                    success = r == 0
                else:
                    torch.hub.download_url_to_file(url, f, progress=threads == 1)  # torch download
                    success = f.is_file()
                if success:
                    break
                elif i < retry:
                    LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
                else:
                    LOGGER.warning(f'❌ Failed to download {url}...')

        if unzip and success and f.suffix in ('.zip', '.tar', '.gz'):
            LOGGER.info(f'Unzipping {f}...')
            if f.suffix == '.zip':
                ZipFile(f).extractall(path=dir)  # unzip
            elif f.suffix == '.tar':
                os.system(f'tar xf {f} --directory {f.parent}')  # unzip
            elif f.suffix == '.gz':
                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
            if delete:
                f.unlink()  # remove zip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multithreaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)
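
A usage sketch (the URLs are the familiar Ultralytics assets, shown only as examples):

urls = ['https://ultralytics.com/assets/coco128.zip',
        'https://ultralytics.com/assets/coco128-seg.zip']
download(urls, dir='datasets', unzip=True, delete=True, threads=2)  # parallel download + auto-extract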

class WorkingDirectory(contextlib.ContextDecorator):
    # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
    def __init__(self, new_dir):
        self.dir = new_dir  # new dir
        self.cwd = Path.cwd().resolve()  # current dir

    def __enter__(self):
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.chdir(self.cwd)

def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
    # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
    file = Path(file)
    assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
    try:  # url1
        LOGGER.info(f'Downloading {url} to {file}...')
        torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
        assert file.exists() and file.stat().st_size > min_bytes, assert_msg  # check
    except Exception as e:  # url2
        if file.exists():
            file.unlink()  # remove partial downloads
        LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
        os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -")  # curl download, retry and resume on fail
    finally:
        if not file.exists() or file.stat().st_size < min_bytes:  # check
            if file.exists():
                file.unlink()  # remove partial downloads
            LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}")
        LOGGER.info('')

def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
    # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
    def github_assets(repository, version='latest'):
        # Return GitHub repo tag and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
        if version != 'latest':
            version = f'tags/{version}'  # i.e. tags/v6.2
        response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json()  # github api
        return response['tag_name'], [x['name'] for x in response['assets']]  # tag, assets

    file = Path(str(file).strip().replace("'", ''))
    if not file.exists():
        # URL specified
        name = Path(urllib.parse.unquote(str(file))).name  # decode '%2F' to '/' etc.
        if str(file).startswith(('http:/', 'https:/')):  # download
            url = str(file).replace(':/', '://')  # Pathlib turns :// -> :/
            file = name.split('?')[0]  # parse authentication https://url.com/file.txt?auth...
            if Path(file).is_file():
                LOGGER.info(f'Found {url} locally at {file}')  # file already exists
            else:
                safe_download(file=file, url=url, min_bytes=1E5)
            return file

        # GitHub assets
        assets = [f'yolov5{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')]  # default
        try:
            tag, assets = github_assets(repo, release)
        except Exception:
            try:
                tag, assets = github_assets(repo)  # latest release
            except Exception:
                try:
                    tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1]
                except Exception:
                    tag = release

        file.parent.mkdir(parents=True, exist_ok=True)  # make parent dir (if required)
        if name in assets:
            url3 = 'https://drive.google.com/drive/folders/1EFQTEUeXWSFww0luse2jB9M1QNZQGwNl'  # backup gdrive mirror
            safe_download(
                file,
                url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
                min_bytes=1E5,
                error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}')

    return str(file)


def get_model(model: str):
    # check for local weights
    pass
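
A usage sketch for attempt_download (asset names follow the yolov5{n,s,m,l,x} scheme above):

weights = attempt_download('yolov5s.pt')  # downloads from the release assets if not found locally
print(weights)  # 'yolov5s.pt'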
326  ultralytics/yolo/utils/instance.py  Normal file
@@ -0,0 +1,326 @@
from collections import abc
from itertools import repeat
from numbers import Number
from typing import List

import numpy as np

from .general import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh


# From PyTorch internals
def _ntuple(n):

    def parse(x):
        return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))

    return parse


to_4tuple = _ntuple(4)

# `xyxy` means left top and right bottom
# `xywh` means center x, center y and width, height (YOLO format)
# `ltwh` means left top and width, height (COCO format)
_formats = ["xyxy", "xywh", "ltwh"]

__all__ = ["Bboxes"]

class Bboxes:
    """Bounding boxes. Only numpy is supported for now."""

    def __init__(self, bboxes, format="xyxy") -> None:
        assert format in _formats
        bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
        assert bboxes.ndim == 2
        assert bboxes.shape[1] == 4
        self.bboxes = bboxes
        self.format = format
        # self.normalized = normalized

    # An earlier draft of convert() that returned a new Bboxes instead of converting in place:
    # def convert(self, format):
    #     assert format in _formats
    #     if self.format == format:
    #         bboxes = self.bboxes
    #     elif self.format == "xyxy":
    #         if format == "xywh":
    #             bboxes = xyxy2xywh(self.bboxes)
    #         else:
    #             bboxes = xyxy2ltwh(self.bboxes)
    #     elif self.format == "xywh":
    #         if format == "xyxy":
    #             bboxes = xywh2xyxy(self.bboxes)
    #         else:
    #             bboxes = xywh2ltwh(self.bboxes)
    #     else:
    #         if format == "xyxy":
    #             bboxes = ltwh2xyxy(self.bboxes)
    #         else:
    #             bboxes = ltwh2xywh(self.bboxes)
    #
    #     return Bboxes(bboxes, format)
    def convert(self, format):
        assert format in _formats
        if self.format == format:
            return
        elif self.format == "xyxy":
            bboxes = xyxy2xywh(self.bboxes) if format == "xywh" else xyxy2ltwh(self.bboxes)
        elif self.format == "xywh":
            bboxes = xywh2xyxy(self.bboxes) if format == "xyxy" else xywh2ltwh(self.bboxes)
        else:
            bboxes = ltwh2xyxy(self.bboxes) if format == "xyxy" else ltwh2xywh(self.bboxes)
        self.bboxes = bboxes
        self.format = format

    def areas(self):
        self.convert("xyxy")
        return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])

    # def denormalize(self, w, h):
    #     if not self.normalized:
    #         return
    #     assert (self.bboxes <= 1.0).all()
    #     self.bboxes[:, 0::2] *= w
    #     self.bboxes[:, 1::2] *= h
    #     self.normalized = False
    #
    # def normalize(self, w, h):
    #     if self.normalized:
    #         return
    #     assert (self.bboxes > 1.0).any()
    #     self.bboxes[:, 0::2] /= w
    #     self.bboxes[:, 1::2] /= h
    #     self.normalized = True
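
A quick sketch of in-place conversion and area computation:

import numpy as np

boxes = Bboxes(np.array([[10., 10., 30., 40.]]), format="xyxy")
boxes.convert("xywh")  # in-place: bboxes is now [[20., 25., 20., 30.]]
print(boxes.areas())   # [600.]; note areas() converts back to xyxy internally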
    def mul(self, scale):
        """
        Args:
            scale (tuple | List | int): the scale for the four coords.
        """
        if isinstance(scale, Number):
            scale = to_4tuple(scale)
        assert isinstance(scale, (tuple, list))
        assert len(scale) == 4
        self.bboxes[:, 0] *= scale[0]
        self.bboxes[:, 1] *= scale[1]
        self.bboxes[:, 2] *= scale[2]
        self.bboxes[:, 3] *= scale[3]

    def add(self, offset):
        """
        Args:
            offset (tuple | List | int): the offset for the four coords.
        """
        if isinstance(offset, Number):
            offset = to_4tuple(offset)
        assert isinstance(offset, (tuple, list))
        assert len(offset) == 4
        self.bboxes[:, 0] += offset[0]
        self.bboxes[:, 1] += offset[1]
        self.bboxes[:, 2] += offset[2]
        self.bboxes[:, 3] += offset[3]

    def __len__(self):
        return len(self.bboxes)
    @classmethod
    def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
        """
        Concatenate a list of Bboxes into a single Bboxes.

        Arguments:
            boxes_list (list[Bboxes])

        Returns:
            Bboxes: the concatenated boxes
        """
        assert isinstance(boxes_list, (list, tuple))
        if not boxes_list:
            return cls(np.empty(0))
        assert all(isinstance(box, Bboxes) for box in boxes_list)

        if len(boxes_list) == 1:
            return boxes_list[0]
        return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))

    def __getitem__(self, index) -> "Bboxes":
        """
        Args:
            index: int, slice, or a BoolArray

        Returns:
            Bboxes: Create a new :class:`Bboxes` by indexing.
        """
        if isinstance(index, int):
            return Bboxes(self.bboxes[index].reshape(1, -1))  # reshape, not view(): numpy's view() takes a dtype
        b = self.bboxes[index]
        assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
        return Bboxes(b)

class Instances:

    def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
        """
        Args:
            bboxes (ndarray): bboxes with shape [N, 4].
            segments (list | ndarray): segments.
            keypoints (ndarray): keypoints with shape [N, 17, 2].
        """
        self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
        self.keypoints = keypoints
        self.normalized = normalized

        if isinstance(segments, list) and len(segments) > 0:
            # list[np.array(1000, 2)] * num_samples
            segments = resample_segments(segments)
            # (N, 1000, 2)
            segments = np.stack(segments, axis=0)
        self.segments = segments  # keep the attribute defined (possibly None) for the is-None checks below
    def convert_bbox(self, format):
        self._bboxes.convert(format=format)

    def bbox_areas(self):
        return self._bboxes.areas()

    def scale(self, scale_w, scale_h, bbox_only=False):
        """Similar to denormalize(), but without toggling the normalized flag."""
        self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
        if bbox_only:
            return
        if self.segments is not None:
            self.segments[..., 0] *= scale_w
            self.segments[..., 1] *= scale_h
        if self.keypoints is not None:
            self.keypoints[..., 0] *= scale_w
            self.keypoints[..., 1] *= scale_h
    def denormalize(self, w, h):
        if not self.normalized:
            return
        self._bboxes.mul(scale=(w, h, w, h))
        if self.segments is not None:
            self.segments[..., 0] *= w
            self.segments[..., 1] *= h
        if self.keypoints is not None:
            self.keypoints[..., 0] *= w
            self.keypoints[..., 1] *= h
        self.normalized = False

    def normalize(self, w, h):
        if self.normalized:
            return
        self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
        if self.segments is not None:
            self.segments[..., 0] /= w
            self.segments[..., 1] /= h
        if self.keypoints is not None:
            self.keypoints[..., 0] /= w
            self.keypoints[..., 1] /= h
        self.normalized = True
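
A round-trip sketch of the normalized-flag handling:

import numpy as np

ins = Instances(bboxes=np.array([[0.5, 0.5, 0.2, 0.4]]), bbox_format="xywh", normalized=True)
ins.denormalize(w=640, h=480)  # -> [[320., 240., 128., 192.]] in pixels
ins.normalize(w=640, h=480)    # back to normalized coordinates; repeated calls are no-ops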

    def add_padding(self, padw, padh):
        # handle rect and mosaic situation
        assert not self.normalized, "you should add padding with absolute coordinates."
        self._bboxes.add(offset=(padw, padh, padw, padh))
        if self.segments is not None:
            self.segments[..., 0] += padw
            self.segments[..., 1] += padh
        if self.keypoints is not None:
            self.keypoints[..., 0] += padw
            self.keypoints[..., 1] += padh

    def __getitem__(self, index) -> "Instances":
        """
        Args:
            index: int, slice, or a BoolArray

        Returns:
            Instances: Create a new :class:`Instances` by indexing.
        """
        segments = self.segments[index] if self.segments is not None else None
        keypoints = self.keypoints[index] if self.keypoints is not None else None
        bboxes = self.bboxes[index]
        bbox_format = self._bboxes.format
        return Instances(
            bboxes=bboxes,
            segments=segments,
            keypoints=keypoints,
            bbox_format=bbox_format,
            normalized=self.normalized,
        )

    def flipud(self, h):
        # assumes xywh boxes, so mirroring the y-center is enough; kept here to keep the flipud augmentation clean
        self.bboxes[:, 1] = h - self.bboxes[:, 1]
        if self.segments is not None:
            self.segments[..., 1] = h - self.segments[..., 1]
        if self.keypoints is not None:
            self.keypoints[..., 1] = h - self.keypoints[..., 1]

    def fliplr(self, w):
        # assumes xywh boxes, so mirroring the x-center is enough; kept here to keep the fliplr augmentation clean
        self.bboxes[:, 0] = w - self.bboxes[:, 0]
        if self.segments is not None:
            self.segments[..., 0] = w - self.segments[..., 0]
        if self.keypoints is not None:
            self.keypoints[..., 0] = w - self.keypoints[..., 0]

    def clip(self, w, h):
        self.convert_bbox(format="xyxy")
        self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
        self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
        if self.segments is not None:
            self.segments[..., 0] = self.segments[..., 0].clip(0, w)
            self.segments[..., 1] = self.segments[..., 1].clip(0, h)
        if self.keypoints is not None:
            self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
            self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)

    def update(self, bboxes, segments=None, keypoints=None):
        new_bboxes = Bboxes(bboxes, format=self._bboxes.format)
        self._bboxes = new_bboxes
        if segments is not None:
            self.segments = segments
        if keypoints is not None:
            self.keypoints = keypoints

    def __len__(self):
        return len(self.bboxes)

    @classmethod
    def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
        """
        Concatenate a list of Instances into a single Instances.

        Arguments:
            instances_list (list[Instances])
            axis

        Returns:
            Instances: the concatenated Instances
        """
        assert isinstance(instances_list, (list, tuple))
        if not instances_list:
            return cls(np.empty(0))
        assert all(isinstance(instance, Instances) for instance in instances_list)

        if len(instances_list) == 1:
            return instances_list[0]

        use_segment = instances_list[0].segments is not None
        use_keypoint = instances_list[0].keypoints is not None
        bbox_format = instances_list[0]._bboxes.format
        normalized = instances_list[0].normalized

        cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
        cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis) if use_segment else None
        cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
        return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)

    @property
    def bboxes(self):
        return self._bboxes.bboxes
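
A usage sketch combining the pieces, e.g. for mosaic-style label merging (values are illustrative):

import numpy as np

a = Instances(bboxes=np.array([[100., 100., 50., 80.]]), bbox_format="xywh", normalized=False)
b = Instances(bboxes=np.array([[300., 200., 40., 40.]]), bbox_format="xywh", normalized=False)
merged = Instances.concatenate([a, b])
merged.fliplr(w=640)           # mirror x-centers for a horizontal flip
merged.convert_bbox("xyxy")
merged.clip(w=640, h=640)      # keep all coordinates inside the image
print(len(merged), merged.bboxes)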
3  ultralytics/yolo/utils/loggers/__init__.py  Normal file
@@ -0,0 +1,3 @@
from .base import default_callbacks

__all__ = ["default_callbacks"]
32  ultralytics/yolo/utils/loggers/base.py  Normal file
@@ -0,0 +1,32 @@
def before_train(trainer):
    # Initialize tensorboard logger
    pass


def on_epoch_start(trainer):
    pass


def on_batch_start(trainer):
    pass


def on_val_start(trainer):
    pass


def on_val_end(trainer):
    pass


def on_model_save(trainer):
    pass


default_callbacks = {
    "before_train": before_train,
    "on_epoch_start": on_epoch_start,
    "on_batch_start": on_batch_start,
    "on_val_start": on_val_start,
    "on_val_end": on_val_end,
    "on_model_save": on_model_save}
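
A minimal sketch of how a trainer might dispatch these hooks (the TrainerSketch class and trigger helper are assumptions; the actual trainer wiring lives elsewhere in this PR):

class TrainerSketch:

    def __init__(self):
        self.callbacks = dict(default_callbacks)  # copy so handlers can be swapped per run

    def trigger(self, event):
        self.callbacks[event](self)  # every hook receives the trainer instance

    def train(self, epochs=3):
        self.trigger("before_train")
        for epoch in range(epochs):
            self.trigger("on_epoch_start")
            self.trigger("on_val_start")
            self.trigger("on_val_end")
            self.trigger("on_model_save")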
27  ultralytics/yolo/utils/metrics.py  Normal file
@@ -0,0 +1,27 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Model validation metrics
"""
import numpy as np


def bbox_ioa(box1, box2, eps=1e-7):
    """Returns the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
    box1: np.array of shape (4,)
    box2: np.array of shape (n, 4)
    returns: np.array of shape (n,)
    """

    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.T

    # Intersection area
    inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                 (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

    # box2 area
    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps

    # Intersection over box2 area
    return inter_area / box2_area
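
A quick numeric sketch of bbox_ioa:

import numpy as np

box1 = np.array([0., 0., 10., 10.])
box2 = np.array([[5., 5., 15., 15.],      # 25 of its 100 px overlap box1
                 [20., 20., 30., 30.]])   # disjoint
print(bbox_ioa(box1, box2))  # approximately [0.25, 0.]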
70  ultralytics/yolo/utils/torch_utils.py  Normal file
@@ -0,0 +1,70 @@
import os
from contextlib import contextmanager

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from ultralytics.yolo.utils import check_version

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    # Context manager to make all processes in distributed training wait for the local master to do something first
    if local_rank not in [-1, 0]:
        dist.barrier(device_ids=[local_rank])
    yield
    if local_rank == 0:
        dist.barrier(device_ids=[0])
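
Typical usage (a sketch; create_dataset and cfg are hypothetical stand-ins, not part of this diff): rank 0 prepares the dataset cache first, the other ranks wait and then reuse it:

with torch_distributed_zero_first(LOCAL_RANK):
    dataset = create_dataset(cfg)  # hypothetical factory; runs on rank 0 before the other ranks proceed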

def DDP_model(model):
    # Model DDP creation with checks
    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
    if check_version(torch.__version__, '1.11.0'):
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
    else:
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

def select_device(device='', batch_size=0, newline=True):
    # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
    # s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
    s = f'YOLOv5 🚀 torch-{torch.__version__} '
    device = str(device).strip().lower().replace('cuda:', '').replace('none', '')  # to string, 'cuda:0' to '0'
    cpu = device == 'cpu'
    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
        assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
            f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(',') if device else '0'  # range(torch.cuda.device_count())  # i.e. 0,1,6,7
        n = len(devices)  # device count
        if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * (len(s) + 1)
        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MiB
        arg = 'cuda:0'
    elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available():  # prefer MPS if available
        s += 'MPS\n'
        arg = 'mps'
    else:  # revert to CPU
        s += 'CPU\n'
        arg = 'cpu'

    if not newline:
        s = s.rstrip()
    print(s)
    return torch.device(arg)
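
Usage sketch (assumes a CUDA machine for the first call; model is a stand-in for any nn.Module):

device = select_device('0', batch_size=16)  # first CUDA GPU, checks batch divisibility across GPUs
model = model.to(device)
cpu = select_device('cpu')                  # force CPU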