ultralytics 8.0.80
single-line docstring fixes (#2060)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -199,7 +199,7 @@ def plt_settings(rcparams={'font.size': 11}, backend='Agg'):
 
 
 def set_logging(name=LOGGING_NAME, verbose=True):
-    # sets up logging for the given name
+    """Sets up logging for the given name."""
     rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
     level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
     logging.config.dictConfig({
@@ -539,12 +539,12 @@ SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'
 
 
 def emojis(string=''):
-    # Return platform-dependent emoji-safe version of string
+    """Return platform-dependent emoji-safe version of string."""
     return string.encode().decode('ascii', 'ignore') if WINDOWS else string
 
 
 def colorstr(*input):
-    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
+    """Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')."""
     *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
     colors = {
         'black': '\033[30m',  # basic colors
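For context, a minimal usage sketch of the `colorstr` helper touched above, based only on its docstring; the `ultralytics.yolo.utils` import path assumes the 8.0.x package layout:

```python
from ultralytics.yolo.utils import colorstr  # 8.0.x layout (assumption)

# Explicit colour arguments, then the single-argument form, which defaults to ('blue', 'bold')
print(colorstr('blue', 'bold', 'hello world'))
print(colorstr('hello world'))
```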
@@ -570,7 +570,8 @@ def colorstr(*input):
 
 
 class TryExcept(contextlib.ContextDecorator):
-    # YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
+    """YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
+
     def __init__(self, msg='', verbose=True):
         self.msg = msg
         self.verbose = verbose
@@ -585,7 +586,8 @@ class TryExcept(contextlib.ContextDecorator):
 
 
 def threaded(func):
-    # Multi-threads a target function and returns thread. Usage: @threaded decorator
+    """Multi-threads a target function and returns thread. Usage: @threaded decorator."""
+
    def wrapper(*args, **kwargs):
        thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
        thread.start()
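A hedged usage sketch of the two decorators whose docstrings are updated above (the import path assumes the 8.0.x layout; the decorated functions are placeholders):

```python
import time

from ultralytics.yolo.utils import TryExcept, threaded  # 8.0.x layout (assumption)

@TryExcept(msg='demo plot failed')  # exceptions are caught and reported instead of raised
def plot_demo():
    raise ValueError('simulated plotting error')

@threaded  # runs in a daemon thread and returns the Thread object
def background_sleep(seconds):
    time.sleep(seconds)

plot_demo()              # execution continues past the caught error
t = background_sleep(0.1)
t.join()                 # wait for the background thread if the result matters
```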
@@ -703,13 +705,13 @@ def deprecation_warn(arg, new_arg, version=None):
 
 
 def clean_url(url):
-    # Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt
+    """Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt."""
     url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
     return urllib.parse.unquote(url).split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth
 
 
 def url2file(url):
-    # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
+    """Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt."""
     return Path(clean_url(url)).name
 
 
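The intended behaviour of the two URL helpers above, as a small sketch (import path assumed for 8.0.x):

```python
from ultralytics.yolo.utils import clean_url, url2file  # 8.0.x layout (assumption)

url = 'https://url.com/file.txt?auth=token123'
print(clean_url(url))  # 'https://url.com/file.txt'  (auth query stripped)
print(url2file(url))   # 'file.txt'                  (filename only)
```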
@@ -15,20 +15,20 @@ except (ImportError, AssertionError):
 
 COMET_MODE = os.getenv('COMET_MODE', 'online')
 COMET_MODEL_NAME = os.getenv('COMET_MODEL_NAME', 'YOLOv8')
-# determines how many batches of image predictions to log from the validation set
+# Determines how many batches of image predictions to log from the validation set
 COMET_EVAL_BATCH_LOGGING_INTERVAL = int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))
-# determines whether to log confusion matrix every evaluation epoch
+# Determines whether to log confusion matrix every evaluation epoch
 COMET_EVAL_LOG_CONFUSION_MATRIX = (os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'true').lower() == 'true')
-# determines whether to log image predictions every evaluation epoch
+# Determines whether to log image predictions every evaluation epoch
 COMET_EVAL_LOG_IMAGE_PREDICTIONS = (os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true')
 COMET_MAX_IMAGE_PREDICTIONS = int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))
 
-# ensures certain logging functions only run for supported tasks
+# Ensures certain logging functions only run for supported tasks
 COMET_SUPPORTED_TASKS = ['detect']
-# scales reported confidence scores (0.0-1.0) by this value
+# Scales reported confidence scores (0.0-1.0) by this value
 COMET_MAX_CONFIDENCE_SCORE = int(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100))
 
-# names of plots created by YOLOv8 that are logged to Comet
+# Names of plots created by YOLOv8 that are logged to Comet
 EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix'
 LABEL_PLOT_NAMES = 'labels', 'labels_correlogram'
 
@@ -43,7 +43,7 @@ def _get_experiment_type(mode, project_name):
 
 
 def _create_experiment(args):
-    # Ensures that the experiment object is only created in a single process during distributed training.
+    """Ensures that the experiment object is only created in a single process during distributed training."""
     if RANK not in (-1, 0):
         return
     try:
@@ -83,13 +83,13 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin
 
     resized_image_height, resized_image_width = resized_image_shape
 
-    # convert normalized xywh format predictions to xyxy in resized scale format
+    # Convert normalized xywh format predictions to xyxy in resized scale format
     box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
-    # scale box predictions from resized image scale back to original image scale
+    # Scale box predictions from resized image scale back to original image scale
     box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
     # Convert bounding box format from xyxy to xywh for Comet logging
     box = ops.xyxy2xywh(box)
-    # adjust xy center to correspond top-left corner
+    # Adjust xy center to correspond top-left corner
     box[:2] -= box[2:] / 2
     box = box.tolist()
 
@@ -244,7 +244,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
 
 
 def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
-    # Check file(s) for acceptable suffix
+    """Check file(s) for acceptable suffix."""
     if file and suffix:
         if isinstance(suffix, str):
             suffix = (suffix, )
@@ -255,7 +255,7 @@ def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
 
 
 def check_yolov5u_filename(file: str, verbose: bool = True):
-    # Replace legacy YOLOv5 filenames with updated YOLOv5u filenames
+    """Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
     if ('yolov3' in file or 'yolov5' in file) and 'u' not in file:
         original_file = file
         file = re.sub(r'(.*yolov5([nsmlx]))\.pt', '\\1u.pt', file)  # i.e. yolov5n.pt -> yolov5nu.pt
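The substitution above can be checked in isolation; this snippet only re-runs the regex shown in the hunk:

```python
import re

for name in ('yolov5n.pt', 'yolov5s.pt', 'yolov8n.pt'):
    print(re.sub(r'(.*yolov5([nsmlx]))\.pt', '\\1u.pt', name))
# -> yolov5nu.pt, yolov5su.pt, yolov8n.pt (non-YOLOv5 names are left unchanged)
```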
@@ -269,7 +269,7 @@ def check_yolov5u_filename(file: str, verbose: bool = True):
 
 
 def check_file(file, suffix='', download=True, hard=True):
-    # Search/download file (if necessary) and return path
+    """Search/download file (if necessary) and return path."""
     check_suffix(file, suffix)  # optional
     file = str(file).strip()  # convert to string and strip spaces
     file = check_yolov5u_filename(file)  # yolov5n -> yolov5nu
@@ -300,7 +300,7 @@ def check_yaml(file, suffix=('.yaml', '.yml'), hard=True):
 
 
 def check_imshow(warn=False):
-    # Check if environment supports image displays
+    """Check if environment supports image displays."""
     try:
         assert not any((is_colab(), is_kaggle(), is_docker()))
         cv2.imshow('test', np.zeros((1, 1, 3)))
@@ -346,9 +346,10 @@ def git_describe(path=ROOT):  # path must be a directory
 
 
 def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
-    # Print function arguments (optional args dict)
+    """Print function arguments (optional args dict)."""
+
    def strip_auth(v):
-        # Clean longer Ultralytics HUB URLs by stripping potential authentication information
+        """Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
        return clean_url(v) if (isinstance(v, str) and v.startswith('http') and len(v) > 100) else v
 
    x = inspect.currentframe().f_back  # previous frame
@@ -59,6 +59,6 @@ def generate_ddp_command(world_size, trainer):
 
 
 def ddp_cleanup(trainer, file):
-    # delete temp file if created
+    """Delete temp file if created."""
     if f'{id(trainer)}.py' in file:  # if temp_file suffix in file
         os.remove(file)
@@ -21,7 +21,7 @@ GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES]
 
 
 def is_url(url, check=True):
-    # Check if string is URL and check if URL exists
+    """Check if string is URL and check if URL exists."""
     with contextlib.suppress(Exception):
         url = str(url)
         result = parse.urlparse(url)
@@ -141,11 +141,11 @@ def safe_download(url,
 
 
 def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
-    # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
+    """Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc."""
     from ultralytics.yolo.utils import SETTINGS  # scoped for circular import
 
     def github_assets(repository, version='latest'):
-        # Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])
+        """Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])."""
        if version != 'latest':
            version = f'tags/{version}'  # i.e. tags/v6.2
        response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json()  # github api
@@ -8,7 +8,8 @@ from pathlib import Path
 
 
 class WorkingDirectory(contextlib.ContextDecorator):
-    # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
+    """Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager."""
+
     def __init__(self, new_dir):
         self.dir = new_dir  # new dir
         self.cwd = Path.cwd().resolve()  # current dir
@@ -56,19 +57,19 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
 
 
 def file_age(path=__file__):
-    # Return days since last file update
+    """Return days since last file update."""
     dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime))  # delta
     return dt.days  # + dt.seconds / 86400  # fractional days
 
 
 def file_date(path=__file__):
-    # Return human-readable file modification date, i.e. '2021-3-26'
+    """Return human-readable file modification date, i.e. '2021-3-26'."""
     t = datetime.fromtimestamp(Path(path).stat().st_mtime)
     return f'{t.year}-{t.month}-{t.day}'
 
 
 def file_size(path):
-    # Return file/dir size (MB)
+    """Return file/dir size (MB)."""
     if isinstance(path, (str, Path)):
         mb = 1 << 20  # bytes to MiB (1024 ** 2)
         path = Path(path)
@@ -80,6 +81,6 @@ def file_size(path):
 
 
 def get_latest_run(search_dir='.'):
-    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
+    """Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
     last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
     return max(last_list, key=os.path.getctime) if last_list else ''
@@ -11,7 +11,8 @@ from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy,
 
 
 def _ntuple(n):
-    # From PyTorch internals
+    """From PyTorch internals."""
+
    def parse(x):
        return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))
 
@@ -29,7 +30,7 @@ __all__ = 'Bboxes',  # tuple or list
 
 
 class Bboxes:
-    """Now only numpy is supported"""
+    """Now only numpy is supported."""
 
     def __init__(self, bboxes, format='xyxy') -> None:
         assert format in _formats
@@ -80,7 +81,7 @@ class Bboxes:
         return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
 
     # def denormalize(self, w, h):
-    #     if not self.normalized:
+    #     if not self.normalized:
     #         return
     #     assert (self.bboxes <= 1.0).all()
     #     self.bboxes[:, 0::2] *= w
@@ -207,7 +208,7 @@ class Instances:
         self._bboxes.areas()
 
     def scale(self, scale_w, scale_h, bbox_only=False):
-        """this might be similar with denormalize func but without normalized sign"""
+        """this might be similar with denormalize func but without normalized sign."""
         self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
         if bbox_only:
             return
@@ -240,7 +241,7 @@ class Instances:
         self.normalized = True
 
     def add_padding(self, padw, padh):
-        # handle rect and mosaic situation
+        """Handle rect and mosaic situation."""
        assert not self.normalized, 'you should add padding with absolute coordinates.'
        self._bboxes.add(offset=(padw, padh, padw, padh))
        self.segments[..., 0] += padw
@@ -9,7 +9,8 @@ from .tal import bbox2dist
 
 
 class VarifocalLoss(nn.Module):
-    # Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367
+    """Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""
+
     def __init__(self):
         super().__init__()
 
@@ -29,7 +30,7 @@ class BboxLoss(nn.Module):
         self.use_dfl = use_dfl
 
     def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
-        # IoU loss
+        """IoU loss."""
         weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
         iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
         loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
@@ -46,7 +47,7 @@ class BboxLoss(nn.Module):
 
     @staticmethod
     def _df_loss(pred_dist, target):
-        # Return sum of left and right DFL losses
+        """Return sum of left and right DFL losses."""
         # Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
         tl = target.long()  # target left
         tr = tl + 1  # target right
@@ -16,9 +16,9 @@ from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept, plt_settings
 OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
 
 
-# boxes
+# Boxes
 def box_area(box):
-    # box = xyxy(4,n)
+    """Return box area, where box shape is xyxy(4,n)."""
     return (box[2] - box[0]) * (box[3] - box[1])
 
 
@@ -175,9 +175,10 @@ def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#iss
     return 1.0 - 0.5 * eps, 0.5 * eps
 
 
-# losses
+# Losses
 class FocalLoss(nn.Module):
-    # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
+    """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
+
     def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
         super().__init__()
         self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
@@ -341,7 +342,7 @@ class ConfusionMatrix:
 
 
 def smooth(y, f=0.05):
-    # Box filter of fraction f
+    """Box filter of fraction f."""
     nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
     p = np.ones(nf // 2)  # ones padding
     yp = np.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
@@ -350,7 +351,7 @@ def smooth(y, f=0.05):
 
 @plt_settings()
 def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
-    # Precision-recall curve
+    """Plots a precision-recall curve."""
     fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
     py = np.stack(py, axis=1)
 
@@ -373,7 +374,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
 
 @plt_settings()
 def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
-    # Metric-confidence curve
+    """Plots a metric-confidence curve."""
     fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
 
     if 0 < len(names) < 21:  # display per-class legend if < 21 classes
@@ -614,23 +615,23 @@ class Metric(SimpleClass):
         return self.all_ap.mean() if len(self.all_ap) else 0.0
 
     def mean_results(self):
-        """Mean of results, return mp, mr, map50, map"""
+        """Mean of results, return mp, mr, map50, map."""
         return [self.mp, self.mr, self.map50, self.map]
 
     def class_result(self, i):
-        """class-aware result, return p[i], r[i], ap50[i], ap[i]"""
+        """class-aware result, return p[i], r[i], ap50[i], ap[i]."""
         return self.p[i], self.r[i], self.ap50[i], self.ap[i]
 
     @property
     def maps(self):
-        """mAP of each class"""
+        """mAP of each class."""
         maps = np.zeros(self.nc) + self.map
         for i, c in enumerate(self.ap_class_index):
             maps[c] = self.ap[i]
         return maps
 
     def fitness(self):
-        # Model fitness as a weighted combination of metrics
+        """Model fitness as a weighted combination of metrics."""
        w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
        return (np.array(self.mean_results()) * w).sum()
 
@@ -800,7 +801,7 @@ class SegmentMetrics(SimpleClass):
 
     @property
     def ap_class_index(self):
-        # boxes and masks have the same ap_class_index
+        """Boxes and masks have the same ap_class_index."""
         return self.box.ap_class_index
 
     @property
@@ -926,7 +927,7 @@ class ClassifyMetrics(SimpleClass):
         self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
 
     def process(self, targets, pred):
-        # target classes and predicted classes
+        """Target classes and predicted classes."""
         pred, targets = torch.cat(pred), torch.cat(targets)
         correct = (targets[:, None] == pred).float()
         acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1)  # (top1, top5) accuracy
@@ -246,7 +246,7 @@ def non_max_suppression(
         i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
         i = i[:max_det]  # limit detections
         if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
-            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+            # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
             iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
             weights = iou * scores[None]  # box weights
             x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
@@ -21,7 +21,7 @@ from .ops import clip_boxes, scale_image, xywh2xyxy, xyxy2xywh
 class Colors:
     # Ultralytics color palette https://ultralytics.com/
     def __init__(self):
-        # hex = matplotlib.colors.TABLEAU_COLORS.values()
+        """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
        hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
                '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
        self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
@@ -63,7 +63,7 @@ class Annotator:
         else:  # use cv2
             self.im = im
         self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2)  # line width
-        # pose
+        # Pose
        self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9],
                         [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
 
@@ -115,7 +115,7 @@ class Annotator:
             alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
         """
         if self.pil:
-            # convert to numpy first
+            # Convert to numpy first
             self.im = np.asarray(self.im).copy()
         if len(masks) == 0:
             self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
@@ -136,7 +136,7 @@ class Annotator:
         im_mask_np = im_mask.byte().cpu().numpy()
         self.im[:] = im_mask_np if retina_masks else scale_image(im_mask_np, self.im.shape)
         if self.pil:
-            # convert im back to PIL and update draw
+            # Convert im back to PIL and update draw
             self.fromarray(self.im)
 
     def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True):
@@ -152,7 +152,7 @@ class Annotator:
         Note: `kpt_line=True` currently only supports human pose plotting.
         """
         if self.pil:
-            # convert to numpy first
+            # Convert to numpy first
             self.im = np.asarray(self.im).copy()
         nkpt, ndim = kpts.shape
         is_pose = nkpt == 17 and ndim == 3
@@ -183,11 +183,11 @@ class Annotator:
                     continue
                 cv2.line(self.im, pos1, pos2, [int(x) for x in self.limb_color[i]], thickness=2, lineType=cv2.LINE_AA)
         if self.pil:
-            # convert im back to PIL and update draw
+            # Convert im back to PIL and update draw
            self.fromarray(self.im)
 
    def rectangle(self, xy, fill=None, outline=None, width=1):
-        # Add rectangle to image (PIL-only)
+        """Add rectangle to image (PIL-only)."""
        self.draw.rectangle(xy, fill, outline, width)
 
    def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
@@ -202,12 +202,12 @@ class Annotator:
             cv2.putText(self.im, text, xy, 0, self.lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA)
 
     def fromarray(self, im):
-        # Update self.im from a numpy array
+        """Update self.im from a numpy array."""
         self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
         self.draw = ImageDraw.Draw(self.im)
 
     def result(self):
-        # Return annotated image as array
+        """Return annotated image as array."""
         return np.asarray(self.im)
 
 
@@ -217,18 +217,18 @@ def plot_labels(boxes, cls, names=(), save_dir=Path('')):
     import pandas as pd
     import seaborn as sn
 
-    # plot dataset labels
+    # Plot dataset labels
     LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
     b = boxes.transpose()  # classes, boxes
     nc = int(cls.max() + 1)  # number of classes
     x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
 
-    # seaborn correlogram
+    # Seaborn correlogram
     sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
     plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
     plt.close()
 
-    # matplotlib labels
+    # Matplotlib labels
     ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
     y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
     with contextlib.suppress(Exception):  # color histogram bars by class
@@ -242,7 +242,7 @@ def plot_labels(boxes, cls, names=(), save_dir=Path('')):
     sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
     sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
 
-    # rectangles
+    # Rectangles
     boxes[:, 0:2] = 0.5  # center
     boxes = xywh2xyxy(boxes) * 1000
     img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
@@ -401,7 +401,7 @@ def plot_images(images,
 
 @plt_settings()
 def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False):
-    # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
+    """Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')."""
     import pandas as pd
     save_dir = Path(file).parent if file else Path(dir)
     if segment:
@@ -436,7 +436,7 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False):
 
 
 def output_to_target(output, max_det=300):
-    # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
+    """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
     targets = []
     for i, o in enumerate(output):
         box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
@@ -48,7 +48,7 @@ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
         is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype)  # (b, n_max_boxes, h*w)
         mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos)  # (b, n_max_boxes, h*w)
         fg_mask = mask_pos.sum(-2)
-    # find each grid serve which gt(index)
+    # Find each grid serve which gt(index)
     target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
     return target_gt_idx, fg_mask, mask_pos
 
@@ -112,10 +112,10 @@ class TaskAlignedAssigner(nn.Module):
 
         target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
 
-        # assigned target
+        # Assigned target
         target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
 
-        # normalize
+        # Normalize
         align_metric *= mask_pos
         pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
         pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
@@ -125,13 +125,13 @@ class TaskAlignedAssigner(nn.Module):
         return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
 
     def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
-        # get in_gts mask, (b, max_num_obj, h*w)
+        """Get in_gts mask, (b, max_num_obj, h*w)."""
         mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
-        # get anchor_align metric, (b, max_num_obj, h*w)
+        # Get anchor_align metric, (b, max_num_obj, h*w)
         align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
-        # get topk_metric mask, (b, max_num_obj, h*w)
+        # Get topk_metric mask, (b, max_num_obj, h*w)
         mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.repeat([1, 1, self.topk]).bool())
-        # merge all mask to a final mask, (b, max_num_obj, h*w)
+        # Merge all mask to a final mask, (b, max_num_obj, h*w)
         mask_pos = mask_topk * mask_in_gts * mask_gt
 
         return mask_pos, align_metric, overlaps
@@ -145,7 +145,7 @@ class TaskAlignedAssigner(nn.Module):
         ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
         ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
         ind[1] = gt_labels.long().squeeze(-1)  # b, max_num_obj
-        # get the scores of each grid for each gt cls
+        # Get the scores of each grid for each gt cls
         bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w
 
         # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
@@ -30,7 +30,7 @@ TORCH_2_X = check_version(torch.__version__, minimum='2.0')
 
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
-    # Decorator to make all processes in distributed training wait for each local_master to do something
+    """Decorator to make all processes in distributed training wait for each local_master to do something."""
     initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
     if initialized and local_rank not in (-1, 0):
         dist.barrier(device_ids=[local_rank])
@@ -40,7 +40,8 @@ def torch_distributed_zero_first(local_rank: int):
 
 
 def smart_inference_mode():
-    # Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
+    """Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
+
    def decorate(fn):
        return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
 
@@ -48,7 +49,7 @@ def smart_inference_mode():
 
 
 def select_device(device='', batch=0, newline=False, verbose=True):
-    # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
+    """Selects PyTorch Device. Options are device = None or 'cpu' or 0 or '0' or '0,1,2,3'."""
     s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
     device = str(device).lower()
     for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
@@ -84,7 +85,7 @@ def select_device(device='', batch=0, newline=False, verbose=True):
             s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
         arg = 'cuda:0'
     elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available() and TORCH_2_X:
-        # prefer MPS if available
+        # Prefer MPS if available
         s += 'MPS\n'
         arg = 'mps'
     else:  # revert to CPU
@@ -97,14 +98,14 @@ def select_device(device='', batch=0, newline=False, verbose=True):
 
 
 def time_sync():
-    # PyTorch-accurate time
+    """PyTorch-accurate time."""
     if torch.cuda.is_available():
         torch.cuda.synchronize()
     return time.time()
 
 
 def fuse_conv_and_bn(conv, bn):
-    # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+    """Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
@@ -128,7 +129,7 @@ def fuse_conv_and_bn(conv, bn):
 
 
 def fuse_deconv_and_bn(deconv, bn):
-    # Fuse ConvTranspose2d() and BatchNorm2d() layers
+    """Fuse ConvTranspose2d() and BatchNorm2d() layers."""
    fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
                                    deconv.out_channels,
                                    kernel_size=deconv.kernel_size,
@@ -139,7 +140,7 @@ def fuse_deconv_and_bn(deconv, bn):
                                     groups=deconv.groups,
                                     bias=True).requires_grad_(False).to(deconv.weight.device)
 
-    # prepare filters
+    # Prepare filters
     w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
     w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
     fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
@@ -153,7 +154,7 @@ def fuse_deconv_and_bn(deconv, bn):
 
 
 def model_info(model, detailed=False, verbose=True, imgsz=640):
-    # Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
+    """Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]."""
     if not verbose:
         return
     n_p = get_num_params(model)
@@ -174,17 +175,17 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
 
 
 def get_num_params(model):
-    # Return the total number of parameters in a YOLO model
+    """Return the total number of parameters in a YOLO model."""
     return sum(x.numel() for x in model.parameters())
 
 
 def get_num_gradients(model):
-    # Return the total number of parameters with gradients in a YOLO model
+    """Return the total number of parameters with gradients in a YOLO model."""
     return sum(x.numel() for x in model.parameters() if x.requires_grad)
 
 
 def get_flops(model, imgsz=640):
-    # Return a YOLO model's FLOPs
+    """Return a YOLO model's FLOPs."""
     try:
         model = de_parallel(model)
         p = next(model.parameters())
@@ -199,7 +200,7 @@ def get_flops(model, imgsz=640):
 
 
 def initialize_weights(model):
-    # Initialize model weights to random values
+    """Initialize model weights to random values."""
     for m in model.modules():
         t = type(m)
         if t is nn.Conv2d:
@@ -224,7 +225,7 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
 
 
 def make_divisible(x, divisor):
-    # Returns nearest x divisible by divisor
+    """Returns nearest x divisible by divisor."""
     if isinstance(divisor, torch.Tensor):
         divisor = int(divisor.max())  # to int
     return math.ceil(x / divisor) * divisor
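A quick numeric check of `make_divisible` as defined above (it rounds up to the next multiple of the divisor):

```python
import math

def make_divisible(x, divisor):
    """Returns nearest x divisible by divisor (copied from the hunk above)."""
    return math.ceil(x / divisor) * divisor

print(make_divisible(640, 32))  # 640 (already divisible)
print(make_divisible(641, 32))  # 672 (rounded up to the next multiple of 32)
```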
@@ -240,7 +241,7 @@ def copy_attr(a, b, include=(), exclude=()):
 
 
 def get_latest_opset():
-    # Return second-most (for maturity) recently supported ONNX opset by this version of torch
+    """Return second-most (for maturity) recently supported ONNX opset by this version of torch."""
     return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) - 1  # opset
 
 
@@ -250,22 +251,22 @@ def intersect_dicts(da, db, exclude=()):
 
 
 def is_parallel(model):
-    # Returns True if model is of type DP or DDP
+    """Returns True if model is of type DP or DDP."""
     return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
 
 
 def de_parallel(model):
-    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
+    """De-parallelize a model: returns single-GPU model if model is of type DP or DDP."""
     return model.module if is_parallel(model) else model
 
 
 def one_cycle(y1=0.0, y2=1.0, steps=100):
-    # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
+    """Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
     return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
 
 
 def init_seeds(seed=0, deterministic=False):
-    # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
+    """Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html."""
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
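And a worked check of the `one_cycle` ramp above, evaluating the returned lambda at the start, middle and end of the schedule:

```python
import math

def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Sinusoidal ramp from y1 to y2 over `steps` (copied from the hunk above)."""
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

lf = one_cycle(1.0, 0.01, 100)
print(lf(0), lf(50), lf(100))  # approximately 1.0, 0.505, 0.01
```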
@@ -280,14 +281,14 @@ def init_seeds(seed=0, deterministic=False):
 
 
 class ModelEMA:
-    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
+    """Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
     Keeps a moving average of everything in the model state_dict (parameters and buffers)
     For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
     To disable EMA set the `enabled` attribute to `False`.
     """
 
     def __init__(self, model, decay=0.9999, tau=2000, updates=0):
-        # Create EMA
+        """Create EMA."""
         self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
         self.updates = updates  # number of EMA updates
         self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
@@ -296,7 +297,7 @@ class ModelEMA:
         self.enabled = True
 
     def update(self, model):
-        # Update EMA parameters
+        """Update EMA parameters."""
         if self.enabled:
             self.updates += 1
             d = self.decay(self.updates)
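For reference, the exponential decay ramp used by ModelEMA above starts near zero (so early updates track the raw model closely) and approaches `decay` as updates accumulate; a small sketch of just that lambda:

```python
import math

decay, tau = 0.9999, 2000
d = lambda x: decay * (1 - math.exp(-x / tau))  # same ramp as in ModelEMA.__init__ above

for updates in (1, 100, 2000, 10000):
    print(updates, round(d(updates), 4))
# 1 -> 0.0005, 100 -> 0.0488, 2000 -> 0.6321, 10000 -> 0.9932
```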