ultralytics 8.0.75 fixes and updates (#1967)

Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Jonathan Rayner <jonathan.j.rayner@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-04-13 02:04:10 +02:00
committed by GitHub
parent e5cb35edfc
commit 48c4483795
9 changed files with 128 additions and 98 deletions

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = '8.0.74'
__version__ = '8.0.75'
from ultralytics.hub import start
from ultralytics.yolo.engine.model import YOLO

View File

@ -56,7 +56,7 @@ download: |
cls = int(row[5]) - 1
box = convert_box(img_size, tuple(map(int, row[:4])))
lines.append(f"{cls} {' '.join(f'{x:.6f}' for x in box)}\n")
with open(str(f).replace(os.sep + 'annotations' + os.sep, os.sep + 'labels' + os.sep), 'w') as fl:
with open(str(f).replace(f'{os.sep}annotations{os.sep}', f'{os.sep}labels{os.sep}'), 'w') as fl:
fl.writelines(lines) # write label.txt

View File

@ -21,7 +21,7 @@ class BaseTensor(SimpleClass):
"""
Attributes:
tensor (torch.Tensor): A tensor.
data (torch.Tensor): Base tensor.
orig_shape (tuple): Original image size, in the format (height, width).
Methods:
@ -31,20 +31,14 @@ class BaseTensor(SimpleClass):
to(): Returns a copy of the tensor with the specified device and dtype.
"""
def __init__(self, tensor, orig_shape) -> None:
super().__init__()
assert isinstance(tensor, torch.Tensor)
self.tensor = tensor
def __init__(self, data, orig_shape) -> None:
self.data = data
self.orig_shape = orig_shape
@property
def shape(self):
return self.data.shape
@property
def data(self):
return self.tensor
def cpu(self):
return self.__class__(self.data.cpu(), self.orig_shape)
@ -164,7 +158,6 @@ class Results(SimpleClass):
font_size=None,
font='Arial.ttf',
pil=False,
example='abc',
img=None,
img_gpu=None,
kpt_line=True,
@ -183,7 +176,6 @@ class Results(SimpleClass):
font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
font (str): The font to use for the text.
pil (bool): Whether to return the image as a PIL Image.
example (str): An example string to display. Useful for indicating the expected format of the output.
img (numpy.ndarray): Plot to another image. if not, plot to original image.
img_gpu (torch.Tensor): Normalized image in gpu with shape (1, 3, 640, 640), for faster mask plotting.
kpt_line (bool): Whether to draw lines connecting keypoints.
@ -201,12 +193,16 @@ class Results(SimpleClass):
conf = kwargs['show_conf']
assert type(conf) == bool, '`show_conf` should be of boolean type, i.e, show_conf=True/False'
annotator = Annotator(deepcopy(self.orig_img if img is None else img), line_width, font_size, font, pil,
example)
names = self.names
annotator = Annotator(deepcopy(self.orig_img if img is None else img),
line_width,
font_size,
font,
pil,
example=names)
pred_boxes, show_boxes = self.boxes, boxes
pred_masks, show_masks = self.masks, masks
pred_probs, show_probs = self.probs, probs
names = self.names
keypoints = self.keypoints
if pred_masks and show_masks:
if img_gpu is None:
@ -236,13 +232,13 @@ class Results(SimpleClass):
def verbose(self):
"""
Return log string for each tasks.
Return log string for each task.
"""
log_string = ''
probs = self.probs
boxes = self.boxes
if len(self) == 0:
return log_string if probs is not None else log_string + '(no detections), '
return log_string if probs is not None else f'{log_string}(no detections), '
if probs is not None:
n5 = min(len(self.names), 5)
top5i = probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices
@ -346,26 +342,26 @@ class Boxes(BaseTensor):
boxes = boxes[None, :]
n = boxes.shape[-1]
assert n in (6, 7), f'expected `n` in [6, 7], but got {n}' # xyxy, (track_id), conf, cls
super().__init__(boxes, orig_shape)
self.is_track = n == 7
self.boxes = boxes
self.orig_shape = torch.as_tensor(orig_shape, device=boxes.device) if isinstance(boxes, torch.Tensor) \
else np.asarray(orig_shape)
@property
def xyxy(self):
return self.boxes[:, :4]
return self.data[:, :4]
@property
def conf(self):
return self.boxes[:, -2]
return self.data[:, -2]
@property
def cls(self):
return self.boxes[:, -1]
return self.data[:, -1]
@property
def id(self):
return self.boxes[:, -3] if self.is_track else None
return self.data[:, -3] if self.is_track else None
@property
@lru_cache(maxsize=2) # maxsize 1 should suffice
@ -386,8 +382,9 @@ class Boxes(BaseTensor):
LOGGER.info('results.pandas() method not yet implemented')
@property
def data(self):
return self.boxes
def boxes(self):
LOGGER.warning("WARNING ⚠️ 'Boxes.boxes' is deprecated. Use 'Boxes.data' instead.")
return self.data
class Masks(BaseTensor):
@ -416,8 +413,7 @@ class Masks(BaseTensor):
def __init__(self, masks, orig_shape) -> None:
if masks.ndim == 2:
masks = masks[None, :]
self.masks = masks # N, h, w
self.orig_shape = orig_shape
super().__init__(masks, orig_shape)
@property
@lru_cache(maxsize=1)
@ -432,17 +428,18 @@ class Masks(BaseTensor):
def xyn(self):
# Segments (normalized)
return [
ops.scale_coords(self.masks.shape[1:], x, self.orig_shape, normalize=True)
for x in ops.masks2segments(self.masks)]
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
for x in ops.masks2segments(self.data)]
@property
@lru_cache(maxsize=1)
def xy(self):
# Segments (pixels)
return [
ops.scale_coords(self.masks.shape[1:], x, self.orig_shape, normalize=False)
for x in ops.masks2segments(self.masks)]
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
for x in ops.masks2segments(self.data)]
@property
def data(self):
return self.masks
def masks(self):
LOGGER.warning("WARNING ⚠️ 'Masks.masks' is deprecated. Use 'Masks.data' instead.")
return self.data

View File

@ -17,6 +17,7 @@ from types import SimpleNamespace
from typing import Union
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
@ -116,7 +117,7 @@ class SimpleClass:
attr = []
for a in dir(self):
v = getattr(self, a)
if not callable(v) and not a.startswith('__'):
if not callable(v) and not a.startswith('_'):
if isinstance(v, SimpleClass):
# Display only the module and class name for subclasses
s = f'{a}: {v.__module__}.{v.__class__.__name__} object'
@ -164,6 +165,39 @@ class IterableSimpleNamespace(SimpleNamespace):
return getattr(self, key, default)
def plt_settings(rcparams={'font.size': 11}, backend='Agg'):
"""
Decorator to temporarily set rc parameters and the backend for a plotting function.
Usage:
decorator: @plt_settings({"font.size": 12})
context manager: with plt_settings({"font.size": 12}):
Args:
rcparams (dict): Dictionary of rc parameters to set.
backend (str, optional): Name of the backend to use. Defaults to 'Agg'.
Returns:
callable: Decorated function with temporarily set rc parameters and backend.
"""
def decorator(func):
def wrapper(*args, **kwargs):
original_backend = plt.get_backend()
plt.switch_backend(backend)
with plt.rc_context(rcparams):
result = func(*args, **kwargs)
plt.switch_backend(original_backend)
return result
return wrapper
return decorator
def set_logging(name=LOGGING_NAME, verbose=True):
# sets up logging for the given name
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings

View File

@ -128,7 +128,8 @@ def check_latest_pypi_version(package_name='ultralytics'):
Returns:
str: The latest version of the package.
"""
response = requests.get(f'https://pypi.org/pypi/{package_name}/json')
requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
response = requests.get(f'https://pypi.org/pypi/{package_name}/json', verify=False)
if response.status_code == 200:
return response.json()['info']['version']
return None

View File

@ -11,7 +11,7 @@ import numpy as np
import torch
import torch.nn as nn
from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept
from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept, plt_settings
OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
@ -234,6 +234,7 @@ class ConfusionMatrix:
return tp[:-1], fp[:-1] # remove background class
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
@plt_settings()
def plot(self, normalize=True, save_dir='', names=()):
import seaborn as sn
@ -277,6 +278,7 @@ def smooth(y, f=0.05):
return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
@plt_settings()
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
# Precision-recall curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
@ -299,6 +301,7 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
plt.close(fig)
@plt_settings()
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
# Metric-confidence curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)

View File

@ -5,22 +5,18 @@ import math
from pathlib import Path
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from PIL import __version__ as pil_version
from ultralytics.yolo.utils import LOGGER, TryExcept, threaded
from ultralytics.yolo.utils import LOGGER, TryExcept, plt_settings, threaded
from .checks import check_font, check_version, is_ascii
from .files import increment_path
from .ops import clip_boxes, scale_image, xywh2xyxy, xyxy2xywh
matplotlib.rc('font', **{'size': 11})
matplotlib.use('Agg') # for writing to files only
class Colors:
# Ultralytics color palette https://ultralytics.com/
@ -212,6 +208,7 @@ class Annotator:
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path('')):
import pandas as pd
import seaborn as sn
@ -228,7 +225,6 @@ def plot_labels(boxes, cls, names=(), save_dir=Path('')):
plt.close()
# matplotlib labels
matplotlib.use('svg') # faster
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
with contextlib.suppress(Exception): # color histogram bars by class
@ -244,9 +240,9 @@ def plot_labels(boxes, cls, names=(), save_dir=Path('')):
# rectangles
boxes[:, 0:2] = 0.5 # center
boxes = xywh2xyxy(boxes) * 2000
img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
for cls, box in zip(cls[:1000], boxes[:1000]):
boxes = xywh2xyxy(boxes) * 1000
img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
for cls, box in zip(cls[:500], boxes[:500]):
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
ax[1].imshow(img)
ax[1].axis('off')
@ -256,7 +252,6 @@ def plot_labels(boxes, cls, names=(), save_dir=Path('')):
ax[a].spines[s].set_visible(False)
plt.savefig(save_dir / 'labels.jpg', dpi=200)
matplotlib.use('Agg')
plt.close()
@ -400,6 +395,7 @@ def plot_images(images,
annotator.im.save(fname) # save
@plt_settings()
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False):
# Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
import pandas as pd

View File

@ -79,7 +79,7 @@ class SegLoss(Loss):
# targets
try:
batch_idx = batch['batch_idx'].view(-1, 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes'].to(dtype)), 1)
targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)