Update .pre-commit-config.yaml (#1026)

Glenn Jocher
2023-02-17 22:26:40 +01:00
committed by GitHub
parent 9047d737f4
commit edd3ff1669
76 changed files with 928 additions and 935 deletions
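Most of the diff below is mechanical: the updated pre-commit configuration standardizes on single-quoted string literals, so double-quoted strings throughout the package are rewritten to single quotes. As a rough sketch of what such a hook checks for (a hypothetical helper, not part of this commit), Python's tokenize module can locate double-quoted literals:

# find_double_quotes.py - illustrative only, not part of this commit
import sys
import tokenize

def find_double_quoted_strings(path):
    # yield (line, col, text) for every string literal written with double quotes,
    # including docstrings, which real hooks typically exempt
    with open(path, 'rb') as f:
        for tok in tokenize.tokenize(f.readline):
            if tok.type == tokenize.STRING and tok.string.lstrip('rbufRBUF').startswith('"'):
                yield (*tok.start, tok.string)

if __name__ == '__main__':
    for line, col, text in find_double_quoted_strings(sys.argv[1]):
        print(f'{sys.argv[1]}:{line}:{col}: {text}')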


@@ -27,7 +27,7 @@ from ultralytics import __version__
# Constants
FILE = Path(__file__).resolve()
ROOT = FILE.parents[2] # YOLO
DEFAULT_CFG_PATH = ROOT / "yolo/cfg/default.yaml"
DEFAULT_CFG_PATH = ROOT / 'yolo/cfg/default.yaml'
RANK = int(os.getenv('RANK', -1))
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
@@ -111,7 +111,7 @@ class IterableSimpleNamespace(SimpleNamespace):
return iter(vars(self).items())
def __str__(self):
return '\n'.join(f"{k}={v}" for k, v in vars(self).items())
return '\n'.join(f'{k}={v}' for k, v in vars(self).items())
def __getattr__(self, attr):
name = self.__class__.__name__
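For reference, the class in this hunk makes a SimpleNamespace iterable and printable; a minimal usage sketch with illustrative values:

from types import SimpleNamespace

class IterableSimpleNamespace(SimpleNamespace):
    # mirrors the two methods shown in the hunk above
    def __iter__(self):
        return iter(vars(self).items())

    def __str__(self):
        return '\n'.join(f'{k}={v}' for k, v in vars(self).items())

ns = IterableSimpleNamespace(imgsz=640, epochs=100)
print(dict(ns))  # {'imgsz': 640, 'epochs': 100}
print(ns)        # imgsz=640 (newline) epochs=100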
@@ -288,7 +288,7 @@ def is_pytest_running():
(bool): True if pytest is running, False otherwise.
"""
with contextlib.suppress(Exception):
return "pytest" in sys.modules
return 'pytest' in sys.modules
return False
@@ -336,7 +336,7 @@ def get_git_origin_url():
"""
if is_git_dir():
with contextlib.suppress(subprocess.CalledProcessError):
origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
origin = subprocess.check_output(['git', 'config', '--get', 'remote.origin.url'])
return origin.decode().strip()
return None # if not git dir or on error
@@ -350,7 +350,7 @@ def get_git_branch():
"""
if is_git_dir():
with contextlib.suppress(subprocess.CalledProcessError):
origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
origin = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
return origin.decode().strip()
return None # if not git dir or on error
@@ -365,9 +365,9 @@ def get_latest_pypi_version(package_name='ultralytics'):
Returns:
str: The latest version of the package.
"""
response = requests.get(f"https://pypi.org/pypi/{package_name}/json")
response = requests.get(f'https://pypi.org/pypi/{package_name}/json')
if response.status_code == 200:
return response.json()["info"]["version"]
return response.json()['info']['version']
return None
@@ -424,28 +424,28 @@ def emojis(string=''):
def colorstr(*input):
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
*args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
colors = {
"black": "\033[30m", # basic colors
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
"bright_black": "\033[90m", # bright colors
"bright_red": "\033[91m",
"bright_green": "\033[92m",
"bright_yellow": "\033[93m",
"bright_blue": "\033[94m",
"bright_magenta": "\033[95m",
"bright_cyan": "\033[96m",
"bright_white": "\033[97m",
"end": "\033[0m", # misc
"bold": "\033[1m",
"underline": "\033[4m"}
return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
'black': '\033[30m', # basic colors
'red': '\033[31m',
'green': '\033[32m',
'yellow': '\033[33m',
'blue': '\033[34m',
'magenta': '\033[35m',
'cyan': '\033[36m',
'white': '\033[37m',
'bright_black': '\033[90m', # bright colors
'bright_red': '\033[91m',
'bright_green': '\033[92m',
'bright_yellow': '\033[93m',
'bright_blue': '\033[94m',
'bright_magenta': '\033[95m',
'bright_cyan': '\033[96m',
'bright_white': '\033[97m',
'end': '\033[0m', # misc
'bold': '\033[1m',
'underline': '\033[4m'}
return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
def remove_ansi_codes(string):
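Usage of colorstr follows the docstring above: leading arguments select colors and styles, the last argument is the string. A quick sketch:

print(colorstr('blue', 'bold', 'hello world'))  # bold blue on ANSI-capable terminals
print(colorstr('hello world'))                  # single-arg form defaults to blue + bold
print('\033[34m\033[1mhello world\033[0m')      # the equivalent raw escape sequence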
@@ -466,21 +466,21 @@ def set_logging(name=LOGGING_NAME, verbose=True):
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
logging.config.dictConfig({
"version": 1,
"disable_existing_loggers": False,
"formatters": {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
name: {
"format": "%(message)s"}},
"handlers": {
'format': '%(message)s'}},
'handlers': {
name: {
"class": "logging.StreamHandler",
"formatter": name,
"level": level}},
"loggers": {
'class': 'logging.StreamHandler',
'formatter': name,
'level': level}},
'loggers': {
name: {
"level": level,
"handlers": [name],
"propagate": False}}})
'level': level,
'handlers': [name],
'propagate': False}}})
class TryExcept(contextlib.ContextDecorator):
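The dictConfig call above wires one formatter, one stream handler and one logger under the same name, with message-only output. A standalone sketch of the same configuration (the logger name here is illustrative; the real call uses LOGGING_NAME):

import logging
import logging.config

name = 'ultralytics'  # illustrative
logging.config.dictConfig({
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {name: {'format': '%(message)s'}},
    'handlers': {name: {'class': 'logging.StreamHandler', 'formatter': name, 'level': logging.INFO}},
    'loggers': {name: {'level': logging.INFO, 'handlers': [name], 'propagate': False}}})
logging.getLogger(name).info('message-only output, no level or name prefix')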
@@ -521,10 +521,10 @@ def set_sentry():
return None # do not send event
event['tags'] = {
"sys_argv": sys.argv[0],
"sys_argv_name": Path(sys.argv[0]).name,
"install": 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
"os": ENVIRONMENT}
'sys_argv': sys.argv[0],
'sys_argv_name': Path(sys.argv[0]).name,
'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
'os': ENVIRONMENT}
return event
if SETTINGS['sync'] and \
@@ -533,24 +533,24 @@ def set_sentry():
not is_pytest_running() and \
not is_github_actions_ci() and \
((is_pip_package() and not is_git_dir()) or
(get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")):
(get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git' and get_git_branch() == 'main')):
import hashlib
import sentry_sdk # noqa
sentry_sdk.init(
dsn="https://f805855f03bb4363bc1e16cb7d87b654@o4504521589325824.ingest.sentry.io/4504521592406016",
dsn='https://f805855f03bb4363bc1e16cb7d87b654@o4504521589325824.ingest.sentry.io/4504521592406016',
debug=False,
traces_sample_rate=1.0,
release=__version__,
environment='production', # 'dev' or 'production'
before_send=before_send,
ignore_errors=[KeyboardInterrupt, FileNotFoundError])
sentry_sdk.set_user({"id": SETTINGS['uuid']})
sentry_sdk.set_user({'id': SETTINGS['uuid']})
# Disable all sentry logging
for logger in "sentry_sdk", "sentry_sdk.errors":
for logger in 'sentry_sdk', 'sentry_sdk.errors':
logging.getLogger(logger).setLevel(logging.CRITICAL)
@@ -620,7 +620,7 @@ if WINDOWS:
setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging
# Check first-install steps
PREFIX = colorstr("Ultralytics: ")
PREFIX = colorstr('Ultralytics: ')
SETTINGS = get_settings()
DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory
ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \


@@ -11,7 +11,7 @@ except (ImportError, AssertionError):
clearml = None
def _log_images(imgs_dict, group="", step=0):
def _log_images(imgs_dict, group='', step=0):
task = Task.current_task()
if task:
for k, v in imgs_dict.items():
@@ -20,7 +20,7 @@ def _log_images(imgs_dict, group="", step=0):
def on_pretrain_routine_start(trainer):
# TODO: reuse existing task
task = Task.init(project_name=trainer.args.project or "YOLOv8",
task = Task.init(project_name=trainer.args.project or 'YOLOv8',
task_name=trainer.args.name,
tags=['YOLOv8'],
output_uri=True,
@@ -31,15 +31,15 @@ def on_pretrain_routine_start(trainer):
def on_train_epoch_end(trainer):
if trainer.epoch == 1:
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, "Mosaic", trainer.epoch)
_log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic', trainer.epoch)
def on_fit_epoch_end(trainer):
if trainer.epoch == 0:
model_info = {
"Parameters": get_num_params(trainer.model),
"GFLOPs": round(get_flops(trainer.model), 3),
"Inference speed (ms/img)": round(trainer.validator.speed[1], 3)}
'Parameters': get_num_params(trainer.model),
'GFLOPs': round(get_flops(trainer.model), 3),
'Inference speed (ms/img)': round(trainer.validator.speed[1], 3)}
Task.current_task().connect(model_info, name='Model')
@@ -50,7 +50,7 @@ def on_train_end(trainer):
callbacks = {
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end} if clearml else {}
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if clearml else {}
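Each integration module ends with the same pattern: a callbacks dict mapping event names to handlers, empty when the integration is not installed. A hedged sketch of how a trainer could dispatch these (the dispatch loop is an assumption for illustration, not code from this commit):

def run_callback(trainer, event, callbacks):
    # invoke the handler registered for this event, if any
    handler = callbacks.get(event)
    if handler:
        handler(trainer)

# e.g. run_callback(trainer, 'on_train_epoch_end', callbacks)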


@@ -10,13 +10,13 @@ except ImportError:
def on_pretrain_routine_start(trainer):
experiment = comet_ml.Experiment(project_name=trainer.args.project or "YOLOv8")
experiment = comet_ml.Experiment(project_name=trainer.args.project or 'YOLOv8')
experiment.log_parameters(vars(trainer.args))
def on_train_epoch_end(trainer):
experiment = comet_ml.get_global_experiment()
experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
if trainer.epoch == 1:
for f in trainer.save_dir.glob('train_batch*.jpg'):
experiment.log_image(f, name=f.stem, step=trainer.epoch + 1)
@@ -27,19 +27,19 @@ def on_fit_epoch_end(trainer):
experiment.log_metrics(trainer.metrics, step=trainer.epoch + 1)
if trainer.epoch == 0:
model_info = {
"model/parameters": get_num_params(trainer.model),
"model/GFLOPs": round(get_flops(trainer.model), 3),
"model/speed(ms)": round(trainer.validator.speed[1], 3)}
'model/parameters': get_num_params(trainer.model),
'model/GFLOPs': round(get_flops(trainer.model), 3),
'model/speed(ms)': round(trainer.validator.speed[1], 3)}
experiment.log_metrics(model_info, step=trainer.epoch + 1)
def on_train_end(trainer):
experiment = comet_ml.get_global_experiment()
experiment.log_model("YOLOv8", file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True)
experiment.log_model('YOLOv8', file_or_folder=str(trainer.best), file_name='best.pt', overwrite=True)
callbacks = {
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end} if comet_ml else {}
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_train_epoch_end': on_train_epoch_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_train_end': on_train_end} if comet_ml else {}


@@ -11,7 +11,7 @@ def on_pretrain_routine_end(trainer):
session = getattr(trainer, 'hub_session', None)
if session:
# Start timer for upload rate limit
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀")
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
session.t = {'metrics': time(), 'ckpt': time()} # start timer on self.rate_limit
@@ -31,7 +31,7 @@ def on_model_save(trainer):
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - session.t['ckpt'] > session.rate_limits['ckpt']:
LOGGER.info(f"{PREFIX}Uploading checkpoint {session.model_id}")
LOGGER.info(f'{PREFIX}Uploading checkpoint {session.model_id}')
session.upload_model(trainer.epoch, trainer.last, is_best)
session.t['ckpt'] = time() # reset timer
@@ -40,11 +40,11 @@ def on_train_end(trainer):
session = getattr(trainer, 'hub_session', None)
if session:
# Upload final model and metrics with exponential standoff
LOGGER.info(f"{PREFIX}Training completed successfully ✅\n"
f"{PREFIX}Uploading final {session.model_id}")
LOGGER.info(f'{PREFIX}Training completed successfully ✅\n'
f'{PREFIX}Uploading final {session.model_id}')
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics['metrics/mAP50-95(B)'], final=True)
session.shutdown() # stop heartbeats
LOGGER.info(f"{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀")
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
def on_train_start(trainer):
@@ -64,11 +64,11 @@ def on_export_start(exporter):
callbacks = {
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_model_save": on_model_save,
"on_train_end": on_train_end,
"on_train_start": on_train_start,
"on_val_start": on_val_start,
"on_predict_start": on_predict_start,
"on_export_start": on_export_start}
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_model_save': on_model_save,
'on_train_end': on_train_end,
'on_train_start': on_train_start,
'on_val_start': on_val_start,
'on_predict_start': on_predict_start,
'on_export_start': on_export_start}


@@ -16,7 +16,7 @@ def on_pretrain_routine_start(trainer):
def on_batch_end(trainer):
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
def on_fit_epoch_end(trainer):
@@ -24,6 +24,6 @@ def on_fit_epoch_end(trainer):
callbacks = {
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_fit_epoch_end": on_fit_epoch_end,
"on_batch_end": on_batch_end}
'on_pretrain_routine_start': on_pretrain_routine_start,
'on_fit_epoch_end': on_fit_epoch_end,
'on_batch_end': on_batch_end}


@@ -71,7 +71,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
msg = "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " \
"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
if max_dim != 1:
raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
raise ValueError(f'imgsz={imgsz} is not a valid image size. {msg}')
LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
imgsz = [max(imgsz)]
# Make image size a multiple of the stride
@@ -87,9 +87,9 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
return sz
def check_version(current: str = "0.0.0",
minimum: str = "0.0.0",
name: str = "version ",
def check_version(current: str = '0.0.0',
minimum: str = '0.0.0',
name: str = 'version ',
pinned: bool = False,
hard: bool = False,
verbose: bool = False) -> bool:
@@ -109,7 +109,7 @@ def check_version(current: str = "0.0.0",
"""
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
result = (current == minimum) if pinned else (current >= minimum) # bool
warning_message = f"WARNING ⚠️ {name}{minimum} is required by YOLOv8, but {name}{current} is currently installed"
warning_message = f'WARNING ⚠️ {name}{minimum} is required by YOLOv8, but {name}{current} is currently installed'
if hard:
assert result, emojis(warning_message) # assert min requirements met
if verbose and not result:
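check_version usage per the semantics above:

check_version(current='8.0.0', minimum='7.0.0')               # True
check_version(current='8.0.0', minimum='8.0.0', pinned=True)  # True only on an exact match
check_version(current='1.13.1', minimum='2.0.0', hard=True)   # raises AssertionError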
@@ -155,7 +155,7 @@ def check_online() -> bool:
"""
import socket
with contextlib.suppress(Exception):
host = socket.gethostbyname("www.github.com")
host = socket.gethostbyname('www.github.com')
socket.create_connection((host, 80), timeout=2)
return True
return False
@@ -182,7 +182,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
file = None
if isinstance(requirements, Path): # requirements.txt file
file = requirements.resolve()
assert file.exists(), f"{prefix} {file} not found, check failed."
assert file.exists(), f'{prefix} {file} not found, check failed.'
with file.open() as f:
requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
elif isinstance(requirements, str):
@@ -200,7 +200,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
if s and install and AUTOINSTALL: # check environment variable
LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
try:
assert check_online(), "AutoUpdate skipped (offline)"
assert check_online(), 'AutoUpdate skipped (offline)'
LOGGER.info(subprocess.check_output(f'pip install {s} {cmds}', shell=True).decode())
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
@@ -217,19 +217,19 @@ def check_suffix(file='yolov8n.pt', suffix=('.pt',), msg=''):
for f in file if isinstance(file, (list, tuple)) else [file]:
s = Path(f).suffix.lower() # file suffix
if len(s):
assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}'
def check_yolov5u_filename(file: str):
# Replace legacy YOLOv5 filenames with updated YOLOv5u filenames
if 'yolov3' in file or 'yolov5' in file and 'u' not in file:
original_file = file
file = re.sub(r"(.*yolov5([nsmlx]))\.", "\\1u.", file) # i.e. yolov5n.pt -> yolov5nu.pt
file = re.sub(r"(.*yolov3(|-tiny|-spp))\.", "\\1u.", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
file = re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', file) # i.e. yolov5n.pt -> yolov5nu.pt
file = re.sub(r'(.*yolov3(|-tiny|-spp))\.', '\\1u.', file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
if file != original_file:
LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n")
f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs '
f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n')
return file
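The two substitutions map legacy YOLOv5/YOLOv3 weight names onto their 'u' counterparts; for example:

import re

print(re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', 'yolov5n.pt'))         # yolov5nu.pt
print(re.sub(r'(.*yolov3(|-tiny|-spp))\.', '\\1u.', 'yolov3-spp.pt'))  # yolov3-sppu.pt
print(re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', 'yolov5n6.pt'))        # unchanged, no match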
@@ -290,7 +290,7 @@ def check_yolo(verbose=True):
# System info
gib = 1 << 30 # bytes per GiB
ram = psutil.virtual_memory().total
total, used, free = shutil.disk_usage("/")
total, used, free = shutil.disk_usage('/')
s = f'({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)'
with contextlib.suppress(Exception): # clear display if ipython is installed
from IPython import display


@@ -22,7 +22,7 @@ def find_free_network_port() -> int:
def generate_ddp_file(trainer):
import_path = '.'.join(str(trainer.__class__).split(".")[1:-1])
import_path = '.'.join(str(trainer.__class__).split('.')[1:-1])
if not trainer.resume:
shutil.rmtree(trainer.save_dir) # remove the save_dir
@@ -32,9 +32,9 @@ def generate_ddp_file(trainer):
trainer = {trainer.__class__.__name__}(cfg=cfg)
trainer.train()'''
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
with tempfile.NamedTemporaryFile(prefix="_temp_",
suffix=f"{id(trainer)}.py",
mode="w+",
with tempfile.NamedTemporaryFile(prefix='_temp_',
suffix=f'{id(trainer)}.py',
mode='w+',
encoding='utf-8',
dir=USER_CONFIG_DIR / 'DDP',
delete=False) as file:
@@ -47,18 +47,18 @@ def generate_ddp_command(world_size, trainer):
# Get file and args (do not use sys.argv due to security vulnerability)
exclude_args = ['save_dir']
args = [f"{k}={v}" for k, v in vars(trainer.args).items() if k not in exclude_args]
args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
file = generate_ddp_file(trainer) # if argv[0].endswith('yolo') else os.path.abspath(argv[0])
# Build command
torch_distributed_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
torch_distributed_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
cmd = [
sys.executable, "-m", torch_distributed_cmd, "--nproc_per_node", f"{world_size}", "--master_port",
f"{find_free_network_port()}", file] + args
sys.executable, '-m', torch_distributed_cmd, '--nproc_per_node', f'{world_size}', '--master_port',
f'{find_free_network_port()}', file] + args
return cmd, file
def ddp_cleanup(trainer, file):
# delete temp file if created
if f"{id(trainer)}.py" in file: # if temp_file suffix in file
if f'{id(trainer)}.py' in file: # if temp_file suffix in file
os.remove(file)
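For a two-GPU run on torch>=1.9, the command list assembled by generate_ddp_command comes out roughly as follows (interpreter path, port, temp-file name and trainer args are illustrative and vary per run):

# illustrative shape of (cmd, file)[0]
['/usr/bin/python3', '-m', 'torch.distributed.run',
 '--nproc_per_node', '2', '--master_port', '12345',
 '/home/user/.config/Ultralytics/DDP/_temp_140234.py',
 'task=detect', 'epochs=100']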


@@ -95,14 +95,14 @@ def safe_download(url,
torch.hub.download_url_to_file(url, f, progress=progress)
else:
from ultralytics.yolo.utils import TQDM_BAR_FORMAT
with request.urlopen(url) as response, tqdm(total=int(response.getheader("Content-Length", 0)),
with request.urlopen(url) as response, tqdm(total=int(response.getheader('Content-Length', 0)),
desc=desc,
disable=not progress,
unit='B',
unit_scale=True,
unit_divisor=1024,
bar_format=TQDM_BAR_FORMAT) as pbar:
with open(f, "wb") as f_opened:
with open(f, 'wb') as f_opened:
for data in response:
f_opened.write(data)
pbar.update(len(data))
@@ -171,7 +171,7 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
tag, assets = github_assets(repo) # latest release
except Exception:
try:
tag = subprocess.check_output(["git", "tag"]).decode().split()[-1]
tag = subprocess.check_output(['git', 'tag']).decode().split()[-1]
except Exception:
tag = release


@@ -24,15 +24,15 @@ to_4tuple = _ntuple(4)
# `xyxy` means left top and right bottom
# `xywh` means center x, center y and width, height(yolo format)
# `ltwh` means left top and width, height(coco format)
_formats = ["xyxy", "xywh", "ltwh"]
_formats = ['xyxy', 'xywh', 'ltwh']
__all__ = ["Bboxes"]
__all__ = ['Bboxes']
class Bboxes:
"""Now only numpy is supported"""
def __init__(self, bboxes, format="xyxy") -> None:
def __init__(self, bboxes, format='xyxy') -> None:
assert format in _formats
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
assert bboxes.ndim == 2
@@ -67,17 +67,17 @@ class Bboxes:
assert format in _formats
if self.format == format:
return
elif self.format == "xyxy":
bboxes = xyxy2xywh(self.bboxes) if format == "xywh" else xyxy2ltwh(self.bboxes)
elif self.format == "xywh":
bboxes = xywh2xyxy(self.bboxes) if format == "xyxy" else xywh2ltwh(self.bboxes)
elif self.format == 'xyxy':
bboxes = xyxy2xywh(self.bboxes) if format == 'xywh' else xyxy2ltwh(self.bboxes)
elif self.format == 'xywh':
bboxes = xywh2xyxy(self.bboxes) if format == 'xyxy' else xywh2ltwh(self.bboxes)
else:
bboxes = ltwh2xyxy(self.bboxes) if format == "xyxy" else ltwh2xywh(self.bboxes)
bboxes = ltwh2xyxy(self.bboxes) if format == 'xyxy' else ltwh2xywh(self.bboxes)
self.bboxes = bboxes
self.format = format
def areas(self):
self.convert("xyxy")
self.convert('xyxy')
return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
# def denormalize(self, w, h):
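A minimal usage sketch of the conversion path above, with illustrative values (per the comments at the top of the file, 'xywh' is center-x, center-y, width, height):

import numpy as np

boxes = Bboxes(np.array([[50., 50., 20., 10.]]), format='xywh')
boxes.convert('xyxy')
print(boxes.bboxes)   # [[40. 45. 60. 55.]]
print(boxes.areas())  # [200.]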
@@ -128,7 +128,7 @@ class Bboxes:
return len(self.bboxes)
@classmethod
def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
"""
Concatenates a list of Boxes into a single Bboxes
@@ -147,7 +147,7 @@ class Bboxes:
return boxes_list[0]
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
def __getitem__(self, index) -> "Bboxes":
def __getitem__(self, index) -> 'Bboxes':
"""
Args:
index: int, slice, or a BoolArray
@@ -158,13 +158,13 @@ class Bboxes:
if isinstance(index, int):
return Bboxes(self.bboxes[index].view(1, -1))
b = self.bboxes[index]
assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
assert b.ndim == 2, f'Indexing on Bboxes with {index} failed to return a matrix!'
return Bboxes(b)
class Instances:
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
def __init__(self, bboxes, segments=None, keypoints=None, bbox_format='xywh', normalized=True) -> None:
"""
Args:
bboxes (ndarray): bboxes with shape [N, 4].
@@ -227,7 +227,7 @@ class Instances:
def add_padding(self, padw, padh):
# handle rect and mosaic situation
assert not self.normalized, "you should add padding with absolute coordinates."
assert not self.normalized, 'you should add padding with absolute coordinates.'
self._bboxes.add(offset=(padw, padh, padw, padh))
self.segments[..., 0] += padw
self.segments[..., 1] += padh
@@ -235,7 +235,7 @@ class Instances:
self.keypoints[..., 0] += padw
self.keypoints[..., 1] += padh
def __getitem__(self, index) -> "Instances":
def __getitem__(self, index) -> 'Instances':
"""
Args:
index: int, slice, or a BoolArray
@@ -256,7 +256,7 @@ class Instances:
)
def flipud(self, h):
if self._bboxes.format == "xyxy":
if self._bboxes.format == 'xyxy':
y1 = self.bboxes[:, 1].copy()
y2 = self.bboxes[:, 3].copy()
self.bboxes[:, 1] = h - y2
@@ -268,7 +268,7 @@ class Instances:
self.keypoints[..., 1] = h - self.keypoints[..., 1]
def fliplr(self, w):
if self._bboxes.format == "xyxy":
if self._bboxes.format == 'xyxy':
x1 = self.bboxes[:, 0].copy()
x2 = self.bboxes[:, 2].copy()
self.bboxes[:, 0] = w - x2
@@ -281,10 +281,10 @@ class Instances:
def clip(self, w, h):
ori_format = self._bboxes.format
self.convert_bbox(format="xyxy")
self.convert_bbox(format='xyxy')
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
if ori_format != "xyxy":
if ori_format != 'xyxy':
self.convert_bbox(format=ori_format)
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
@@ -304,7 +304,7 @@ class Instances:
return len(self.bboxes)
@classmethod
def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
"""
Concatenates a list of Boxes into a single Bboxes


@@ -16,7 +16,7 @@ class VarifocalLoss(nn.Module):
def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with torch.cuda.amp.autocast(enabled=False):
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") *
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
weight).sum()
return loss
@@ -52,5 +52,5 @@ class BboxLoss(nn.Module):
tr = tl + 1 # target right
wl = tr - target # weight left
wr = 1 - wl # weight right
return (F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl +
F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr).mean(-1, keepdim=True)
return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
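The distribution focal loss above spreads each continuous target across its two nearest integer bins; a worked example for target = 3.7:

target = 3.7
tl = int(target)  # 3, left bin (the code uses target.long())
tr = tl + 1       # 4, right bin
wl = tr - target  # 0.3, weight on the left bin
wr = 1 - wl       # 0.7, weight on the right bin
# the expected bin value recovers the target: 3 * 0.3 + 4 * 0.7 = 3.7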


@@ -238,14 +238,14 @@ class ConfusionMatrix:
nc, nn = self.nc, len(names) # number of classes, names
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
ticklabels = (names + ['background']) if labels else "auto"
ticklabels = (names + ['background']) if labels else 'auto'
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
sn.heatmap(array,
ax=ax,
annot=nc < 30,
annot_kws={
"size": 8},
'size': 8},
cmap='Blues',
fmt='.2f',
square=True,
@@ -287,7 +287,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
ax.set_ylabel('Precision')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
ax.set_title('Precision-Recall Curve')
fig.savefig(save_dir, dpi=250)
plt.close(fig)
@@ -309,7 +309,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
ax.set_ylabel(ylabel)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
ax.set_title(f'{ylabel}-Confidence Curve')
fig.savefig(save_dir, dpi=250)
plt.close(fig)
@@ -343,7 +343,7 @@ def compute_ap(recall, precision):
return ap, mpre, mrec
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=""):
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=''):
""" Compute the average precision, given the recall and precision curves.
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
# Arguments
@@ -507,7 +507,7 @@ class Metric:
class DetMetrics:
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
self.save_dir = save_dir
self.plot = plot
self.names = names
@@ -521,7 +521,7 @@ class DetMetrics:
@property
def keys(self):
return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
def mean_results(self):
return self.box.mean_results()
@@ -543,12 +543,12 @@ class DetMetrics:
@property
def results_dict(self):
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
class SegmentMetrics:
def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
self.save_dir = save_dir
self.plot = plot
self.names = names
@@ -563,7 +563,7 @@ class SegmentMetrics:
plot=self.plot,
save_dir=self.save_dir,
names=self.names,
prefix="Mask")[2:]
prefix='Mask')[2:]
self.seg.nc = len(self.names)
self.seg.update(results_mask)
results_box = ap_per_class(tp_b,
@@ -573,15 +573,15 @@ class SegmentMetrics:
plot=self.plot,
save_dir=self.save_dir,
names=self.names,
prefix="Box")[2:]
prefix='Box')[2:]
self.box.nc = len(self.names)
self.box.update(results_box)
@property
def keys(self):
return [
"metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)",
"metrics/precision(M)", "metrics/recall(M)", "metrics/mAP50(M)", "metrics/mAP50-95(M)"]
'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
'metrics/precision(M)', 'metrics/recall(M)', 'metrics/mAP50(M)', 'metrics/mAP50-95(M)']
def mean_results(self):
return self.box.mean_results() + self.seg.mean_results()
@@ -604,7 +604,7 @@ class SegmentMetrics:
@property
def results_dict(self):
return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
class ClassifyMetrics:
@@ -626,8 +626,8 @@ class ClassifyMetrics:
@property
def results_dict(self):
return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
return dict(zip(self.keys + ['fitness'], [self.top1, self.top5, self.fitness]))
@property
def keys(self):
return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
return ['metrics/accuracy_top1', 'metrics/accuracy_top5']
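The results_dict properties all pair the metric keys with values via zip; e.g. for ClassifyMetrics, with illustrative numbers:

keys = ['metrics/accuracy_top1', 'metrics/accuracy_top5']
top1, top5, fitness = 0.71, 0.92, 0.815
print(dict(zip(keys + ['fitness'], [top1, top5, fitness])))
# {'metrics/accuracy_top1': 0.71, 'metrics/accuracy_top5': 0.92, 'fitness': 0.815}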


@@ -715,4 +715,4 @@ def clean_str(s):
Returns:
(str): a string with special characters replaced by an underscore _
"""
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
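Example of the substitution above:

print(clean_str('cam@1 (left)'))     # 'cam_1 _left_'
print(clean_str('video?=demo.mp4'))  # 'video__demo.mp4'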


@@ -61,7 +61,7 @@ def DDP_model(model):
def select_device(device='', batch=0, newline=False):
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
s = f"Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} "
s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
device = str(device).lower()
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
@@ -74,15 +74,15 @@ def select_device(device='', batch=0, newline=False):
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
LOGGER.info(s)
install = "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no " \
"CUDA devices are seen by torch.\n" if torch.cuda.device_count() == 0 else ""
install = 'See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no ' \
'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else ''
raise ValueError(f"Invalid CUDA 'device={device}' requested."
f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}'
f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}'
f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
f"{install}")
f'{install}')
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
@@ -177,7 +177,7 @@ def model_info(model, verbose=False, imgsz=640):
fused = ' (fused)' if model.is_fused() else ''
fs = f', {flops:.1f} GFLOPs' if flops else ''
m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
LOGGER.info(f"{m} summary{fused}: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
LOGGER.info(f'{m} summary{fused}: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}')
def get_num_params(model):