Predictor support (#65)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -6,6 +6,8 @@ import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import IPython
|
||||
|
||||
# Constants
|
||||
FILE = Path(__file__).resolve()
|
||||
ROOT = FILE.parents[2] # YOLO
|
||||
@ -29,6 +31,23 @@ def is_kaggle():
|
||||
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
|
||||
|
||||
|
||||
def is_notebook():
|
||||
# Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
|
||||
ipython_type = str(type(IPython.get_ipython()))
|
||||
return 'colab' in ipython_type or 'zmqshell' in ipython_type
|
||||
|
||||
|
||||
def is_docker() -> bool:
|
||||
"""Check if the process runs inside a docker container."""
|
||||
if Path("/.dockerenv").exists():
|
||||
return True
|
||||
try: # check if docker is in control groups
|
||||
with open("/proc/self/cgroup") as file:
|
||||
return any("docker" in line for line in file)
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def is_writeable(dir, test=False):
|
||||
# Return True if directory has write permissions, test opening a file with write permissions if test=True
|
||||
if not test:
|
||||
|
@ -6,10 +6,13 @@ from pathlib import Path
|
||||
from subprocess import check_output
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pkg_resources as pkg
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.utils import AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis
|
||||
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis,
|
||||
is_docker, is_notebook)
|
||||
|
||||
|
||||
def is_ascii(s=''):
|
||||
@ -131,6 +134,22 @@ def check_yaml(file, suffix=('.yaml', '.yml')):
|
||||
return check_file(file, suffix)
|
||||
|
||||
|
||||
def check_imshow(warn=False):
|
||||
# Check if environment supports image displays
|
||||
try:
|
||||
assert not is_notebook()
|
||||
assert not is_docker()
|
||||
cv2.imshow('test', np.zeros((1, 1, 3)))
|
||||
cv2.waitKey(1)
|
||||
cv2.destroyAllWindows()
|
||||
cv2.waitKey(1)
|
||||
return True
|
||||
except Exception as e:
|
||||
if warn:
|
||||
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
|
||||
return False
|
||||
|
||||
|
||||
def git_describe(path=ROOT): # path must be a directory
|
||||
# Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
|
||||
try:
|
||||
|
@ -0,0 +1,23 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
|
||||
def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
|
||||
"""
|
||||
Accepts yaml file name or DictConfig containing experiment configuration.
|
||||
Returns training args namespace
|
||||
:param config: Optional file name or DictConfig object
|
||||
"""
|
||||
if isinstance(config, (str, Path)):
|
||||
config = OmegaConf.load(config)
|
||||
elif isinstance(config, Dict):
|
||||
config = OmegaConf.create(config)
|
||||
# override
|
||||
if isinstance(overrides, str):
|
||||
overrides = OmegaConf.load(overrides)
|
||||
elif isinstance(overrides, Dict):
|
||||
overrides = OmegaConf.create(overrides)
|
||||
|
||||
return OmegaConf.merge(config, overrides)
|
||||
|
@ -46,7 +46,22 @@ max_det: 300
|
||||
half: True
|
||||
dnn: False # use OpenCV DNN for ONNX inference
|
||||
plots: False
|
||||
|
||||
# Prediction settings:
|
||||
source: "ultralytics/assets/"
|
||||
view_img: False
|
||||
save_txt: False
|
||||
save_conf: False
|
||||
save_crop: False
|
||||
hide_labels: False # hide labels
|
||||
hide_conf: False
|
||||
vid_stride: 1 # video frame-rate stride
|
||||
line_thickness: 3 # bounding box thickness (pixels)
|
||||
update: False # Update all models
|
||||
visualize: False
|
||||
augment: False
|
||||
agnostic_nms: False # class-agnostic NMS
|
||||
retina_masks: False
|
||||
|
||||
# Hyperparameters ------------------------------------------------------------------------------------------------------
|
||||
lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
|
||||
|
@ -1,7 +1,6 @@
|
||||
import contextlib
|
||||
|
||||
import torchvision
|
||||
import yaml
|
||||
|
||||
from ultralytics.yolo.utils.downloads import attempt_download
|
||||
from ultralytics.yolo.utils.modeling.modules import *
|
||||
|
@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import math
|
||||
import re
|
||||
import time
|
||||
|
||||
import cv2
|
||||
@ -374,3 +375,75 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
|
||||
if upsample:
|
||||
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
|
||||
return masks.gt_(0.5)
|
||||
|
||||
|
||||
def process_mask_native(protos, masks_in, bboxes, shape):
|
||||
"""
|
||||
Crop after upsample.
|
||||
protos: [mask_dim, mask_h, mask_w]
|
||||
masks_in: [n, mask_dim], n is number of masks after nms
|
||||
bboxes: [n, 4], n is number of masks after nms
|
||||
shape: input_image_size, (h, w)
|
||||
return: h, w, n
|
||||
"""
|
||||
c, mh, mw = protos.shape # CHW
|
||||
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
|
||||
gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
|
||||
pad = (mw - shape[1] * gain) / 2, (mh - shape[0] * gain) / 2 # wh padding
|
||||
top, left = int(pad[1]), int(pad[0]) # y, x
|
||||
bottom, right = int(mh - pad[1]), int(mw - pad[0])
|
||||
masks = masks[:, top:bottom, left:right]
|
||||
|
||||
masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW
|
||||
masks = crop_mask(masks, bboxes) # CHW
|
||||
return masks.gt_(0.5)
|
||||
|
||||
|
||||
def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
|
||||
# Rescale coords (xyxy) from img1_shape to img0_shape
|
||||
if ratio_pad is None: # calculate from img0_shape
|
||||
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
||||
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
|
||||
else:
|
||||
gain = ratio_pad[0][0]
|
||||
pad = ratio_pad[1]
|
||||
|
||||
segments[:, 0] -= pad[0] # x padding
|
||||
segments[:, 1] -= pad[1] # y padding
|
||||
segments /= gain
|
||||
clip_segments(segments, img0_shape)
|
||||
if normalize:
|
||||
segments[:, 0] /= img0_shape[1] # width
|
||||
segments[:, 1] /= img0_shape[0] # height
|
||||
return segments
|
||||
|
||||
|
||||
def masks2segments(masks, strategy='largest'):
|
||||
# Convert masks(n,160,160) into segments(n,xy)
|
||||
segments = []
|
||||
for x in masks.int().cpu().numpy().astype('uint8'):
|
||||
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
|
||||
if c:
|
||||
if strategy == 'concat': # concatenate all segments
|
||||
c = np.concatenate([x.reshape(-1, 2) for x in c])
|
||||
elif strategy == 'largest': # select largest segment
|
||||
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
|
||||
else:
|
||||
c = np.zeros((0, 2)) # no segments found
|
||||
segments.append(c.astype('float32'))
|
||||
return segments
|
||||
|
||||
|
||||
def clip_segments(segments, shape):
|
||||
# Clip segments (xy1,xy2,...) to image shape (height, width)
|
||||
if isinstance(segments, torch.Tensor): # faster individually
|
||||
segments[:, 0].clamp_(0, shape[1]) # x
|
||||
segments[:, 1].clamp_(0, shape[0]) # y
|
||||
else: # np.array (faster grouped)
|
||||
segments[:, 0] = segments[:, 0].clip(0, shape[1]) # x
|
||||
segments[:, 1] = segments[:, 1].clip(0, shape[0]) # y
|
||||
|
||||
|
||||
def clean_str(s):
|
||||
# Cleans a string by replacing special characters with underscore _
|
||||
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
|
||||
|
@ -36,6 +36,14 @@ def torch_distributed_zero_first(local_rank: int):
|
||||
dist.barrier(device_ids=[0])
|
||||
|
||||
|
||||
def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
|
||||
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
|
||||
def decorate(fn):
|
||||
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def DDP_model(model):
|
||||
# Model DDP creation with checks
|
||||
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
|
||||
@ -192,14 +200,6 @@ def copy_attr(a, b, include=(), exclude=()):
|
||||
setattr(a, k, v)
|
||||
|
||||
|
||||
def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
|
||||
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
|
||||
def decorate(fn):
|
||||
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def intersect_state_dicts(da, db, exclude=()):
|
||||
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
|
||||
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
|
||||
|
Reference in New Issue
Block a user