ultralytics 8.0.80 single-line docstring fixes (#2060)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-04-16 15:20:11 +02:00
committed by GitHub
parent 31db8ed163
commit 5bce1c3021
48 changed files with 418 additions and 420 deletions

View File

@ -199,7 +199,7 @@ def plt_settings(rcparams={'font.size': 11}, backend='Agg'):
def set_logging(name=LOGGING_NAME, verbose=True):
# sets up logging for the given name
"""Sets up logging for the given name."""
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
logging.config.dictConfig({
@ -539,12 +539,12 @@ SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'
def emojis(string=''):
# Return platform-dependent emoji-safe version of string
"""Return platform-dependent emoji-safe version of string."""
return string.encode().decode('ascii', 'ignore') if WINDOWS else string
def colorstr(*input):
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
"""Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')."""
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
colors = {
'black': '\033[30m', # basic colors
@ -570,7 +570,8 @@ def colorstr(*input):
class TryExcept(contextlib.ContextDecorator):
# YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
"""YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
def __init__(self, msg='', verbose=True):
self.msg = msg
self.verbose = verbose
@ -585,7 +586,8 @@ class TryExcept(contextlib.ContextDecorator):
def threaded(func):
# Multi-threads a target function and returns thread. Usage: @threaded decorator
"""Multi-threads a target function and returns thread. Usage: @threaded decorator."""
def wrapper(*args, **kwargs):
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
@ -703,13 +705,13 @@ def deprecation_warn(arg, new_arg, version=None):
def clean_url(url):
# Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt
"""Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt."""
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
return urllib.parse.unquote(url).split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
def url2file(url):
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
"""Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt."""
return Path(clean_url(url)).name

View File

@ -15,20 +15,20 @@ except (ImportError, AssertionError):
COMET_MODE = os.getenv('COMET_MODE', 'online')
COMET_MODEL_NAME = os.getenv('COMET_MODEL_NAME', 'YOLOv8')
# determines how many batches of image predictions to log from the validation set
# Determines how many batches of image predictions to log from the validation set
COMET_EVAL_BATCH_LOGGING_INTERVAL = int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))
# determines whether to log confusion matrix every evaluation epoch
# Determines whether to log confusion matrix every evaluation epoch
COMET_EVAL_LOG_CONFUSION_MATRIX = (os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'true').lower() == 'true')
# determines whether to log image predictions every evaluation epoch
# Determines whether to log image predictions every evaluation epoch
COMET_EVAL_LOG_IMAGE_PREDICTIONS = (os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true')
COMET_MAX_IMAGE_PREDICTIONS = int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))
# ensures certain logging functions only run for supported tasks
# Ensures certain logging functions only run for supported tasks
COMET_SUPPORTED_TASKS = ['detect']
# scales reported confidence scores (0.0-1.0) by this value
# Scales reported confidence scores (0.0-1.0) by this value
COMET_MAX_CONFIDENCE_SCORE = int(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100))
# names of plots created by YOLOv8 that are logged to Comet
# Names of plots created by YOLOv8 that are logged to Comet
EVALUATION_PLOT_NAMES = 'F1_curve', 'P_curve', 'R_curve', 'PR_curve', 'confusion_matrix'
LABEL_PLOT_NAMES = 'labels', 'labels_correlogram'
@ -43,7 +43,7 @@ def _get_experiment_type(mode, project_name):
def _create_experiment(args):
# Ensures that the experiment object is only created in a single process during distributed training.
"""Ensures that the experiment object is only created in a single process during distributed training."""
if RANK not in (-1, 0):
return
try:
@ -83,13 +83,13 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin
resized_image_height, resized_image_width = resized_image_shape
# convert normalized xywh format predictions to xyxy in resized scale format
# Convert normalized xywh format predictions to xyxy in resized scale format
box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
# scale box predictions from resized image scale back to original image scale
# Scale box predictions from resized image scale back to original image scale
box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
# Convert bounding box format from xyxy to xywh for Comet logging
box = ops.xyxy2xywh(box)
# adjust xy center to correspond top-left corner
# Adjust xy center to correspond top-left corner
box[:2] -= box[2:] / 2
box = box.tolist()

View File

@ -244,7 +244,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
# Check file(s) for acceptable suffix
"""Check file(s) for acceptable suffix."""
if file and suffix:
if isinstance(suffix, str):
suffix = (suffix, )
@ -255,7 +255,7 @@ def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
def check_yolov5u_filename(file: str, verbose: bool = True):
# Replace legacy YOLOv5 filenames with updated YOLOv5u filenames
"""Replace legacy YOLOv5 filenames with updated YOLOv5u filenames."""
if ('yolov3' in file or 'yolov5' in file) and 'u' not in file:
original_file = file
file = re.sub(r'(.*yolov5([nsmlx]))\.pt', '\\1u.pt', file) # i.e. yolov5n.pt -> yolov5nu.pt
@ -269,7 +269,7 @@ def check_yolov5u_filename(file: str, verbose: bool = True):
def check_file(file, suffix='', download=True, hard=True):
# Search/download file (if necessary) and return path
"""Search/download file (if necessary) and return path."""
check_suffix(file, suffix) # optional
file = str(file).strip() # convert to string and strip spaces
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
@ -300,7 +300,7 @@ def check_yaml(file, suffix=('.yaml', '.yml'), hard=True):
def check_imshow(warn=False):
# Check if environment supports image displays
"""Check if environment supports image displays."""
try:
assert not any((is_colab(), is_kaggle(), is_docker()))
cv2.imshow('test', np.zeros((1, 1, 3)))
@ -346,9 +346,10 @@ def git_describe(path=ROOT): # path must be a directory
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
# Print function arguments (optional args dict)
"""Print function arguments (optional args dict)."""
def strip_auth(v):
# Clean longer Ultralytics HUB URLs by stripping potential authentication information
"""Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
return clean_url(v) if (isinstance(v, str) and v.startswith('http') and len(v) > 100) else v
x = inspect.currentframe().f_back # previous frame

View File

@ -59,6 +59,6 @@ def generate_ddp_command(world_size, trainer):
def ddp_cleanup(trainer, file):
# delete temp file if created
"""Delete temp file if created."""
if f'{id(trainer)}.py' in file: # if temp_file suffix in file
os.remove(file)

View File

@ -21,7 +21,7 @@ GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES]
def is_url(url, check=True):
# Check if string is URL and check if URL exists
"""Check if string is URL and check if URL exists."""
with contextlib.suppress(Exception):
url = str(url)
result = parse.urlparse(url)
@ -141,11 +141,11 @@ def safe_download(url,
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
"""Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc."""
from ultralytics.yolo.utils import SETTINGS # scoped for circular import
def github_assets(repository, version='latest'):
# Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])
"""Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])."""
if version != 'latest':
version = f'tags/{version}' # i.e. tags/v6.2
response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api

View File

@ -8,7 +8,8 @@ from pathlib import Path
class WorkingDirectory(contextlib.ContextDecorator):
# Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
"""Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager."""
def __init__(self, new_dir):
self.dir = new_dir # new dir
self.cwd = Path.cwd().resolve() # current dir
@ -56,19 +57,19 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
def file_age(path=__file__):
# Return days since last file update
"""Return days since last file update."""
dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
return dt.days # + dt.seconds / 86400 # fractional days
def file_date(path=__file__):
# Return human-readable file modification date, i.e. '2021-3-26'
"""Return human-readable file modification date, i.e. '2021-3-26'."""
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
return f'{t.year}-{t.month}-{t.day}'
def file_size(path):
# Return file/dir size (MB)
"""Return file/dir size (MB)."""
if isinstance(path, (str, Path)):
mb = 1 << 20 # bytes to MiB (1024 ** 2)
path = Path(path)
@ -80,6 +81,6 @@ def file_size(path):
def get_latest_run(search_dir='.'):
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
return max(last_list, key=os.path.getctime) if last_list else ''

View File

@ -11,7 +11,8 @@ from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy,
def _ntuple(n):
# From PyTorch internals
"""From PyTorch internals."""
def parse(x):
return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))
@ -29,7 +30,7 @@ __all__ = 'Bboxes', # tuple or list
class Bboxes:
"""Now only numpy is supported"""
"""Now only numpy is supported."""
def __init__(self, bboxes, format='xyxy') -> None:
assert format in _formats
@ -80,7 +81,7 @@ class Bboxes:
return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
# def denormalize(self, w, h):
# if not self.normalized:
# if not self.normalized:
# return
# assert (self.bboxes <= 1.0).all()
# self.bboxes[:, 0::2] *= w
@ -207,7 +208,7 @@ class Instances:
self._bboxes.areas()
def scale(self, scale_w, scale_h, bbox_only=False):
"""this might be similar with denormalize func but without normalized sign"""
"""this might be similar with denormalize func but without normalized sign."""
self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
if bbox_only:
return
@ -240,7 +241,7 @@ class Instances:
self.normalized = True
def add_padding(self, padw, padh):
# handle rect and mosaic situation
"""Handle rect and mosaic situation."""
assert not self.normalized, 'you should add padding with absolute coordinates.'
self._bboxes.add(offset=(padw, padh, padw, padh))
self.segments[..., 0] += padw

View File

@ -9,7 +9,8 @@ from .tal import bbox2dist
class VarifocalLoss(nn.Module):
# Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367
"""Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""
def __init__(self):
super().__init__()
@ -29,7 +30,7 @@ class BboxLoss(nn.Module):
self.use_dfl = use_dfl
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
# IoU loss
"""IoU loss."""
weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
@ -46,7 +47,7 @@ class BboxLoss(nn.Module):
@staticmethod
def _df_loss(pred_dist, target):
# Return sum of left and right DFL losses
"""Return sum of left and right DFL losses."""
# Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
tl = target.long() # target left
tr = tl + 1 # target right

View File

@ -16,9 +16,9 @@ from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept, plt_settings
OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
# boxes
# Boxes
def box_area(box):
# box = xyxy(4,n)
"""Return box area, where box shape is xyxy(4,n)."""
return (box[2] - box[0]) * (box[3] - box[1])
@ -175,9 +175,10 @@ def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#iss
return 1.0 - 0.5 * eps, 0.5 * eps
# losses
# Losses
class FocalLoss(nn.Module):
# Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
super().__init__()
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
@ -341,7 +342,7 @@ class ConfusionMatrix:
def smooth(y, f=0.05):
# Box filter of fraction f
"""Box filter of fraction f."""
nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
p = np.ones(nf // 2) # ones padding
yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
@ -350,7 +351,7 @@ def smooth(y, f=0.05):
@plt_settings()
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
# Precision-recall curve
"""Plots a precision-recall curve."""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
py = np.stack(py, axis=1)
@ -373,7 +374,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
@plt_settings()
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
# Metric-confidence curve
"""Plots a metric-confidence curve."""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
if 0 < len(names) < 21: # display per-class legend if < 21 classes
@ -614,23 +615,23 @@ class Metric(SimpleClass):
return self.all_ap.mean() if len(self.all_ap) else 0.0
def mean_results(self):
"""Mean of results, return mp, mr, map50, map"""
"""Mean of results, return mp, mr, map50, map."""
return [self.mp, self.mr, self.map50, self.map]
def class_result(self, i):
"""class-aware result, return p[i], r[i], ap50[i], ap[i]"""
"""class-aware result, return p[i], r[i], ap50[i], ap[i]."""
return self.p[i], self.r[i], self.ap50[i], self.ap[i]
@property
def maps(self):
"""mAP of each class"""
"""mAP of each class."""
maps = np.zeros(self.nc) + self.map
for i, c in enumerate(self.ap_class_index):
maps[c] = self.ap[i]
return maps
def fitness(self):
# Model fitness as a weighted combination of metrics
"""Model fitness as a weighted combination of metrics."""
w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
return (np.array(self.mean_results()) * w).sum()
@ -800,7 +801,7 @@ class SegmentMetrics(SimpleClass):
@property
def ap_class_index(self):
# boxes and masks have the same ap_class_index
"""Boxes and masks have the same ap_class_index."""
return self.box.ap_class_index
@property
@ -926,7 +927,7 @@ class ClassifyMetrics(SimpleClass):
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def process(self, targets, pred):
# target classes and predicted classes
"""Target classes and predicted classes."""
pred, targets = torch.cat(pred), torch.cat(targets)
correct = (targets[:, None] == pred).float()
acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy

View File

@ -246,7 +246,7 @@ def non_max_suppression(
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
i = i[:max_det] # limit detections
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
# Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
weights = iou * scores[None] # box weights
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes

View File

@ -21,7 +21,7 @@ from .ops import clip_boxes, scale_image, xywh2xyxy, xyxy2xywh
class Colors:
# Ultralytics color palette https://ultralytics.com/
def __init__(self):
# hex = matplotlib.colors.TABLEAU_COLORS.values()
"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
@ -63,7 +63,7 @@ class Annotator:
else: # use cv2
self.im = im
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
# pose
# Pose
self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9],
[8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
@ -115,7 +115,7 @@ class Annotator:
alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
"""
if self.pil:
# convert to numpy first
# Convert to numpy first
self.im = np.asarray(self.im).copy()
if len(masks) == 0:
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
@ -136,7 +136,7 @@ class Annotator:
im_mask_np = im_mask.byte().cpu().numpy()
self.im[:] = im_mask_np if retina_masks else scale_image(im_mask_np, self.im.shape)
if self.pil:
# convert im back to PIL and update draw
# Convert im back to PIL and update draw
self.fromarray(self.im)
def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True):
@ -152,7 +152,7 @@ class Annotator:
Note: `kpt_line=True` currently only supports human pose plotting.
"""
if self.pil:
# convert to numpy first
# Convert to numpy first
self.im = np.asarray(self.im).copy()
nkpt, ndim = kpts.shape
is_pose = nkpt == 17 and ndim == 3
@ -183,11 +183,11 @@ class Annotator:
continue
cv2.line(self.im, pos1, pos2, [int(x) for x in self.limb_color[i]], thickness=2, lineType=cv2.LINE_AA)
if self.pil:
# convert im back to PIL and update draw
# Convert im back to PIL and update draw
self.fromarray(self.im)
def rectangle(self, xy, fill=None, outline=None, width=1):
# Add rectangle to image (PIL-only)
"""Add rectangle to image (PIL-only)."""
self.draw.rectangle(xy, fill, outline, width)
def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
@ -202,12 +202,12 @@ class Annotator:
cv2.putText(self.im, text, xy, 0, self.lw / 3, txt_color, thickness=tf, lineType=cv2.LINE_AA)
def fromarray(self, im):
# Update self.im from a numpy array
"""Update self.im from a numpy array."""
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
self.draw = ImageDraw.Draw(self.im)
def result(self):
# Return annotated image as array
"""Return annotated image as array."""
return np.asarray(self.im)
@ -217,18 +217,18 @@ def plot_labels(boxes, cls, names=(), save_dir=Path('')):
import pandas as pd
import seaborn as sn
# plot dataset labels
# Plot dataset labels
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
b = boxes.transpose() # classes, boxes
nc = int(cls.max() + 1) # number of classes
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
# seaborn correlogram
# Seaborn correlogram
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
plt.close()
# matplotlib labels
# Matplotlib labels
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
with contextlib.suppress(Exception): # color histogram bars by class
@ -242,7 +242,7 @@ def plot_labels(boxes, cls, names=(), save_dir=Path('')):
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
# rectangles
# Rectangles
boxes[:, 0:2] = 0.5 # center
boxes = xywh2xyxy(boxes) * 1000
img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
@ -401,7 +401,7 @@ def plot_images(images,
@plt_settings()
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False):
# Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
"""Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')."""
import pandas as pd
save_dir = Path(file).parent if file else Path(dir)
if segment:
@ -436,7 +436,7 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False):
def output_to_target(output, max_det=300):
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
targets = []
for i, o in enumerate(output):
box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)

View File

@ -48,7 +48,7 @@ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype) # (b, n_max_boxes, h*w)
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos) # (b, n_max_boxes, h*w)
fg_mask = mask_pos.sum(-2)
# find each grid serve which gt(index)
# Find each grid serve which gt(index)
target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
return target_gt_idx, fg_mask, mask_pos
@ -112,10 +112,10 @@ class TaskAlignedAssigner(nn.Module):
target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
# assigned target
# Assigned target
target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
# normalize
# Normalize
align_metric *= mask_pos
pos_align_metrics = align_metric.amax(axis=-1, keepdim=True) # b, max_num_obj
pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True) # b, max_num_obj
@ -125,13 +125,13 @@ class TaskAlignedAssigner(nn.Module):
return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
# get in_gts mask, (b, max_num_obj, h*w)
"""Get in_gts mask, (b, max_num_obj, h*w)."""
mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
# get anchor_align metric, (b, max_num_obj, h*w)
# Get anchor_align metric, (b, max_num_obj, h*w)
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
# get topk_metric mask, (b, max_num_obj, h*w)
# Get topk_metric mask, (b, max_num_obj, h*w)
mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.repeat([1, 1, self.topk]).bool())
# merge all mask to a final mask, (b, max_num_obj, h*w)
# Merge all mask to a final mask, (b, max_num_obj, h*w)
mask_pos = mask_topk * mask_in_gts * mask_gt
return mask_pos, align_metric, overlaps
@ -145,7 +145,7 @@ class TaskAlignedAssigner(nn.Module):
ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj
ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes) # b, max_num_obj
ind[1] = gt_labels.long().squeeze(-1) # b, max_num_obj
# get the scores of each grid for each gt cls
# Get the scores of each grid for each gt cls
bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w
# (b, max_num_obj, 1, 4), (b, 1, h*w, 4)

View File

@ -30,7 +30,7 @@ TORCH_2_X = check_version(torch.__version__, minimum='2.0')
@contextmanager
def torch_distributed_zero_first(local_rank: int):
# Decorator to make all processes in distributed training wait for each local_master to do something
"""Decorator to make all processes in distributed training wait for each local_master to do something."""
initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
if initialized and local_rank not in (-1, 0):
dist.barrier(device_ids=[local_rank])
@ -40,7 +40,8 @@ def torch_distributed_zero_first(local_rank: int):
def smart_inference_mode():
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
"""Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
def decorate(fn):
return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
@ -48,7 +49,7 @@ def smart_inference_mode():
def select_device(device='', batch=0, newline=False, verbose=True):
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
"""Selects PyTorch Device. Options are device = None or 'cpu' or 0 or '0' or '0,1,2,3'."""
s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
device = str(device).lower()
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
@ -84,7 +85,7 @@ def select_device(device='', batch=0, newline=False, verbose=True):
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
arg = 'cuda:0'
elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available() and TORCH_2_X:
# prefer MPS if available
# Prefer MPS if available
s += 'MPS\n'
arg = 'mps'
else: # revert to CPU
@ -97,14 +98,14 @@ def select_device(device='', batch=0, newline=False, verbose=True):
def time_sync():
# PyTorch-accurate time
"""PyTorch-accurate time."""
if torch.cuda.is_available():
torch.cuda.synchronize()
return time.time()
def fuse_conv_and_bn(conv, bn):
# Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
"""Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
fusedconv = nn.Conv2d(conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
@ -128,7 +129,7 @@ def fuse_conv_and_bn(conv, bn):
def fuse_deconv_and_bn(deconv, bn):
# Fuse ConvTranspose2d() and BatchNorm2d() layers
"""Fuse ConvTranspose2d() and BatchNorm2d() layers."""
fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
deconv.out_channels,
kernel_size=deconv.kernel_size,
@ -139,7 +140,7 @@ def fuse_deconv_and_bn(deconv, bn):
groups=deconv.groups,
bias=True).requires_grad_(False).to(deconv.weight.device)
# prepare filters
# Prepare filters
w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
@ -153,7 +154,7 @@ def fuse_deconv_and_bn(deconv, bn):
def model_info(model, detailed=False, verbose=True, imgsz=640):
# Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
"""Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]."""
if not verbose:
return
n_p = get_num_params(model)
@ -174,17 +175,17 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
def get_num_params(model):
# Return the total number of parameters in a YOLO model
"""Return the total number of parameters in a YOLO model."""
return sum(x.numel() for x in model.parameters())
def get_num_gradients(model):
# Return the total number of parameters with gradients in a YOLO model
"""Return the total number of parameters with gradients in a YOLO model."""
return sum(x.numel() for x in model.parameters() if x.requires_grad)
def get_flops(model, imgsz=640):
# Return a YOLO model's FLOPs
"""Return a YOLO model's FLOPs."""
try:
model = de_parallel(model)
p = next(model.parameters())
@ -199,7 +200,7 @@ def get_flops(model, imgsz=640):
def initialize_weights(model):
# Initialize model weights to random values
"""Initialize model weights to random values."""
for m in model.modules():
t = type(m)
if t is nn.Conv2d:
@ -224,7 +225,7 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
def make_divisible(x, divisor):
# Returns nearest x divisible by divisor
"""Returns nearest x divisible by divisor."""
if isinstance(divisor, torch.Tensor):
divisor = int(divisor.max()) # to int
return math.ceil(x / divisor) * divisor
@ -240,7 +241,7 @@ def copy_attr(a, b, include=(), exclude=()):
def get_latest_opset():
# Return second-most (for maturity) recently supported ONNX opset by this version of torch
"""Return second-most (for maturity) recently supported ONNX opset by this version of torch."""
return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) - 1 # opset
@ -250,22 +251,22 @@ def intersect_dicts(da, db, exclude=()):
def is_parallel(model):
# Returns True if model is of type DP or DDP
"""Returns True if model is of type DP or DDP."""
return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
def de_parallel(model):
# De-parallelize a model: returns single-GPU model if model is of type DP or DDP
"""De-parallelize a model: returns single-GPU model if model is of type DP or DDP."""
return model.module if is_parallel(model) else model
def one_cycle(y1=0.0, y2=1.0, steps=100):
# lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
"""Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
def init_seeds(seed=0, deterministic=False):
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
"""Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
@ -280,14 +281,14 @@ def init_seeds(seed=0, deterministic=False):
class ModelEMA:
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
"""Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
Keeps a moving average of everything in the model state_dict (parameters and buffers)
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
To disable EMA set the `enabled` attribute to `False`.
"""
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
# Create EMA
"""Create EMA."""
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
self.updates = updates # number of EMA updates
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
@ -296,7 +297,7 @@ class ModelEMA:
self.enabled = True
def update(self, model):
# Update EMA parameters
"""Update EMA parameters."""
if self.enabled:
self.updates += 1
d = self.decay(self.updates)