Trainer + Dataloaders (#27)
Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Ayush Chaurasia <ayushchaurasia@Ayushs-MacBook-Pro.local> Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com>
This commit is contained in:
3
ultralytics/yolo/data/__init__.py
Normal file
3
ultralytics/yolo/data/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .build import build_classification_dataloader, build_dataloader
|
||||
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
|
||||
from .dataset_wrappers import MixAndRectDataset
|
785
ultralytics/yolo/data/augment.py
Normal file
785
ultralytics/yolo/data/augment.py
Normal file
@ -0,0 +1,785 @@
|
||||
import collections
|
||||
import math
|
||||
import random
|
||||
from copy import deepcopy
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
|
||||
from ..utils.general import LOGGER, check_version, colorstr, segment2box
|
||||
from ..utils.instance import Instances
|
||||
from ..utils.metrics import bbox_ioa
|
||||
from .utils import IMAGENET_MEAN, IMAGENET_STD, polygons2masks, polygons2masks_overlap
|
||||
|
||||
|
||||
# TODO: we might need a BaseTransform to make all these augments be compatible with both classification and semantic
|
||||
class BaseTransform:
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def apply_image(self, labels):
|
||||
pass
|
||||
|
||||
def apply_instances(self, labels):
|
||||
pass
|
||||
|
||||
def apply_semantic(self, labels):
|
||||
pass
|
||||
|
||||
def __call__(self, labels):
|
||||
self.apply_image(labels)
|
||||
self.apply_instances(labels)
|
||||
self.apply_semantic(labels)
|
||||
|
||||
|
||||
class Compose:
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, data):
|
||||
for t in self.transforms:
|
||||
data = t(data)
|
||||
return data
|
||||
|
||||
def append(self, transform):
|
||||
self.transforms.append(transform)
|
||||
|
||||
def tolist(self):
|
||||
return self.transforms
|
||||
|
||||
def __repr__(self):
|
||||
format_string = f"{self.__class__.__name__}("
|
||||
for t in self.transforms:
|
||||
format_string += "\n"
|
||||
format_string += f" {t}"
|
||||
format_string += "\n)"
|
||||
return format_string
|
||||
|
||||
|
||||
class BaseMixTransform:
|
||||
"""This implementation is from mmyolo"""
|
||||
|
||||
def __init__(self, pre_transform=None, p=0.0) -> None:
|
||||
self.pre_transform = pre_transform
|
||||
self.p = p
|
||||
|
||||
def __call__(self, labels):
|
||||
if random.uniform(0, 1) > self.p:
|
||||
return labels
|
||||
|
||||
assert "dataset" in labels
|
||||
dataset = labels.pop("dataset")
|
||||
|
||||
# get index of one or three other images
|
||||
indexes = self.get_indexes(dataset)
|
||||
if not isinstance(indexes, collections.abc.Sequence):
|
||||
indexes = [indexes]
|
||||
|
||||
# get images information will be used for Mosaic or MixUp
|
||||
mix_labels = [deepcopy(dataset.get_label_info(index)) for index in indexes]
|
||||
|
||||
if self.pre_transform is not None:
|
||||
for i, data in enumerate(mix_labels):
|
||||
# pre_transform may also require dataset
|
||||
data.update({"dataset": dataset})
|
||||
# before Mosaic or MixUp need to go through
|
||||
# the necessary pre_transform
|
||||
_labels = self.pre_transform(data)
|
||||
_labels.pop("dataset")
|
||||
mix_labels[i] = _labels
|
||||
labels["mix_labels"] = mix_labels
|
||||
|
||||
# Mosaic or MixUp
|
||||
labels = self._mix_transform(labels)
|
||||
|
||||
if "mix_labels" in labels:
|
||||
labels.pop("mix_labels")
|
||||
labels["dataset"] = dataset
|
||||
|
||||
return labels
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_indexes(self, dataset):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Mosaic(BaseMixTransform):
|
||||
"""Mosaic augmentation.
|
||||
Args:
|
||||
img_size (Sequence[int]): Image size after mosaic pipeline of single
|
||||
image. The shape order should be (height, width).
|
||||
Default to (640, 640).
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=640, p=1.0, border=(0, 0)):
|
||||
assert 0 <= p <= 1.0, "The probability should be in range [0, 1]. " f"got {p}."
|
||||
super().__init__(pre_transform=None, p=p)
|
||||
self.img_size = img_size
|
||||
self.border = border
|
||||
|
||||
def get_indexes(self, dataset):
|
||||
return [random.randint(0, len(dataset)) for _ in range(3)]
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
mosaic_labels = []
|
||||
assert labels.get("rect_shape", None) is None, "rect and mosaic is exclusive."
|
||||
assert len(labels.get("mix_labels", [])) > 0, "There are no other images for mosaic augment."
|
||||
s = self.img_size
|
||||
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
|
||||
mix_labels = labels["mix_labels"]
|
||||
for i in range(4):
|
||||
labels_patch = deepcopy(labels) if i == 0 else deepcopy(mix_labels[i - 1])
|
||||
# Load image
|
||||
img = labels_patch["img"]
|
||||
h, w = labels_patch["resized_shape"]
|
||||
|
||||
# place img in img4
|
||||
if i == 0: # top left
|
||||
img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
|
||||
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
|
||||
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
|
||||
elif i == 1: # top right
|
||||
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
|
||||
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
|
||||
elif i == 2: # bottom left
|
||||
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
|
||||
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
|
||||
elif i == 3: # bottom right
|
||||
x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
|
||||
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
|
||||
|
||||
img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
|
||||
padw = x1a - x1b
|
||||
padh = y1a - y1b
|
||||
|
||||
labels_patch = self._update_labels(labels_patch, padw, padh)
|
||||
mosaic_labels.append(labels_patch)
|
||||
final_labels = self._cat_labels(mosaic_labels)
|
||||
final_labels["img"] = img4
|
||||
return final_labels
|
||||
|
||||
def _update_labels(self, labels, padw, padh):
|
||||
"""Update labels"""
|
||||
nh, nw = labels["img"].shape[:2]
|
||||
labels["instances"].convert_bbox(format="xyxy")
|
||||
labels["instances"].denormalize(nw, nh)
|
||||
labels["instances"].add_padding(padw, padh)
|
||||
return labels
|
||||
|
||||
def _cat_labels(self, mosaic_labels):
|
||||
if len(mosaic_labels) == 0:
|
||||
return {}
|
||||
cls = []
|
||||
instances = []
|
||||
for labels in mosaic_labels:
|
||||
cls.append(labels["cls"])
|
||||
instances.append(labels["instances"])
|
||||
final_labels = {
|
||||
"ori_shape": (self.img_size * 2, self.img_size * 2),
|
||||
"resized_shape": (self.img_size * 2, self.img_size * 2),
|
||||
"im_file": mosaic_labels[0]["im_file"],
|
||||
"cls": np.concatenate(cls, 0)}
|
||||
|
||||
final_labels["instances"] = Instances.concatenate(instances, axis=0)
|
||||
final_labels["instances"].clip(self.img_size * 2, self.img_size * 2)
|
||||
return final_labels
|
||||
|
||||
|
||||
class MixUp(BaseMixTransform):
|
||||
|
||||
def __init__(self, pre_transform=None, p=0.0) -> None:
|
||||
super().__init__(pre_transform=pre_transform, p=p)
|
||||
|
||||
def get_indexes(self, dataset):
|
||||
return random.randint(0, len(dataset))
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
im = labels["img"]
|
||||
labels2 = labels["mix_labels"][0]
|
||||
im2 = labels2["img"]
|
||||
cls2 = labels2["cls"]
|
||||
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
|
||||
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
|
||||
im = (im * r + im2 * (1 - r)).astype(np.uint8)
|
||||
cat_instances = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
|
||||
cls = labels["cls"]
|
||||
labels["img"] = im
|
||||
labels["instances"] = cat_instances
|
||||
labels["cls"] = np.concatenate([cls, cls2], 0)
|
||||
return labels
|
||||
|
||||
|
||||
class RandomPerspective:
|
||||
|
||||
def __init__(self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0)):
|
||||
self.degrees = degrees
|
||||
self.translate = translate
|
||||
self.scale = scale
|
||||
self.shear = shear
|
||||
self.perspective = perspective
|
||||
# mosaic border
|
||||
self.border = border
|
||||
|
||||
def affine_transform(self, img):
|
||||
# Center
|
||||
C = np.eye(3)
|
||||
|
||||
C[0, 2] = -img.shape[1] / 2 # x translation (pixels)
|
||||
C[1, 2] = -img.shape[0] / 2 # y translation (pixels)
|
||||
|
||||
# Perspective
|
||||
P = np.eye(3)
|
||||
P[2, 0] = random.uniform(-self.perspective, self.perspective) # x perspective (about y)
|
||||
P[2, 1] = random.uniform(-self.perspective, self.perspective) # y perspective (about x)
|
||||
|
||||
# Rotation and Scale
|
||||
R = np.eye(3)
|
||||
a = random.uniform(-self.degrees, self.degrees)
|
||||
# a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
|
||||
s = random.uniform(1 - self.scale, 1 + self.scale)
|
||||
# s = 2 ** random.uniform(-scale, scale)
|
||||
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
|
||||
|
||||
# Shear
|
||||
S = np.eye(3)
|
||||
S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # x shear (deg)
|
||||
S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # y shear (deg)
|
||||
|
||||
# Translation
|
||||
T = np.eye(3)
|
||||
T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0] # x translation (pixels)
|
||||
T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1] # y translation (pixels)
|
||||
|
||||
# Combined rotation matrix
|
||||
M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
|
||||
# affine image
|
||||
if (self.border[0] != 0) or (self.border[1] != 0) or (M != np.eye(3)).any(): # image changed
|
||||
if self.perspective:
|
||||
img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))
|
||||
else: # affine
|
||||
img = cv2.warpAffine(img, M[:2], dsize=self.size, borderValue=(114, 114, 114))
|
||||
return img, M, s
|
||||
|
||||
def apply_bboxes(self, bboxes, M):
|
||||
"""apply affine to bboxes only.
|
||||
|
||||
Args:
|
||||
bboxes(ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
|
||||
M(ndarray): affine matrix.
|
||||
Returns:
|
||||
new_bboxes(ndarray): bboxes after affine, [num_bboxes, 4].
|
||||
"""
|
||||
n = len(bboxes)
|
||||
if n == 0:
|
||||
return bboxes
|
||||
|
||||
xy = np.ones((n * 4, 3))
|
||||
xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
|
||||
xy = xy @ M.T # transform
|
||||
xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
|
||||
|
||||
# create new boxes
|
||||
x = xy[:, [0, 2, 4, 6]]
|
||||
y = xy[:, [1, 3, 5, 7]]
|
||||
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
|
||||
|
||||
def apply_segments(self, segments, M):
|
||||
"""apply affine to segments and generate new bboxes from segments.
|
||||
|
||||
Args:
|
||||
segments(ndarray): list of segments, [num_samples, 500, 2].
|
||||
M(ndarray): affine matrix.
|
||||
Returns:
|
||||
new_segments(ndarray): list of segments after affine, [num_samples, 500, 2].
|
||||
new_bboxes(ndarray): bboxes after affine, [N, 4].
|
||||
"""
|
||||
n, num = segments.shape[:2]
|
||||
if n == 0:
|
||||
return [], segments
|
||||
|
||||
xy = np.ones((n * num, 3))
|
||||
segments = segments.reshape(-1, 2)
|
||||
xy[:, :2] = segments
|
||||
xy = xy @ M.T # transform
|
||||
xy = xy[:, :2] / xy[:, 2:3]
|
||||
segments = xy.reshape(n, -1, 2)
|
||||
bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0)
|
||||
return bboxes, segments
|
||||
|
||||
def apply_keypoints(self, keypoints, M):
|
||||
"""apply affine to keypoints.
|
||||
|
||||
Args:
|
||||
keypoints(ndarray): keypoints, [N, 17, 2].
|
||||
M(ndarray): affine matrix.
|
||||
Return:
|
||||
new_keypoints(ndarray): keypoints after affine, [N, 17, 2].
|
||||
"""
|
||||
n = len(keypoints)
|
||||
if n == 0:
|
||||
return keypoints
|
||||
new_keypoints = np.ones((n * 17, 3))
|
||||
new_keypoints[:, :2] = keypoints.reshape(n * 17, 2) # num_kpt is hardcoded to 17
|
||||
new_keypoints = new_keypoints @ M.T # transform
|
||||
new_keypoints = (new_keypoints[:, :2] / new_keypoints[:, 2:3]).reshape(n, 34) # perspective rescale or affine
|
||||
new_keypoints[keypoints.reshape(-1, 34) == 0] = 0
|
||||
x_kpts = new_keypoints[:, list(range(0, 34, 2))]
|
||||
y_kpts = new_keypoints[:, list(range(1, 34, 2))]
|
||||
|
||||
x_kpts[np.logical_or.reduce((x_kpts < 0, x_kpts > self.size[0], y_kpts < 0, y_kpts > self.size[1]))] = 0
|
||||
y_kpts[np.logical_or.reduce((x_kpts < 0, x_kpts > self.size[0], y_kpts < 0, y_kpts > self.size[1]))] = 0
|
||||
new_keypoints[:, list(range(0, 34, 2))] = x_kpts
|
||||
new_keypoints[:, list(range(1, 34, 2))] = y_kpts
|
||||
return new_keypoints.reshape(n, 17, 2)
|
||||
|
||||
def __call__(self, labels):
|
||||
"""
|
||||
Affine images and targets.
|
||||
|
||||
Args:
|
||||
img(ndarray): image.
|
||||
labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
|
||||
"""
|
||||
img = labels["img"]
|
||||
cls = labels["cls"]
|
||||
instances = labels["instances"]
|
||||
# make sure the coord formats are right
|
||||
instances.convert_bbox(format="xyxy")
|
||||
instances.denormalize(*img.shape[:2][::-1])
|
||||
|
||||
self.size = img.shape[1] + self.border[1] * 2, img.shape[0] + self.border[0] * 2 # w, h
|
||||
# M is affine matrix
|
||||
# scale for func:`box_candidates`
|
||||
img, M, scale = self.affine_transform(img)
|
||||
|
||||
bboxes = self.apply_bboxes(instances.bboxes, M)
|
||||
|
||||
segments = instances.segments
|
||||
keypoints = instances.keypoints
|
||||
# update bboxes if there are segments.
|
||||
if segments is not None:
|
||||
bboxes, segments = self.apply_segments(segments, M)
|
||||
|
||||
if keypoints is not None:
|
||||
keypoints = self.apply_keypoints(keypoints, M)
|
||||
new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
|
||||
new_instances.clip(*self.size)
|
||||
|
||||
# filter instances
|
||||
instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
|
||||
# make the bboxes have the same scale with new_bboxes
|
||||
i = self.box_candidates(box1=instances.bboxes.T,
|
||||
box2=new_instances.bboxes.T,
|
||||
area_thr=0.01 if segments is not None else 0.10)
|
||||
labels["instances"] = new_instances[i]
|
||||
# clip
|
||||
labels["cls"] = cls[i]
|
||||
labels["img"] = img
|
||||
return labels
|
||||
|
||||
def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
|
||||
# Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
|
||||
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
|
||||
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
|
||||
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
|
||||
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
|
||||
|
||||
|
||||
class RandomHSV:
|
||||
|
||||
def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None:
|
||||
self.hgain = hgain
|
||||
self.sgain = sgain
|
||||
self.vgain = vgain
|
||||
|
||||
def __call__(self, labels):
|
||||
img = labels["img"]
|
||||
if self.hgain or self.sgain or self.vgain:
|
||||
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
|
||||
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
|
||||
dtype = img.dtype # uint8
|
||||
|
||||
x = np.arange(0, 256, dtype=r.dtype)
|
||||
lut_hue = ((x * r[0]) % 180).astype(dtype)
|
||||
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
|
||||
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
|
||||
|
||||
im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
|
||||
cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
|
||||
labels["img"] = img
|
||||
return labels
|
||||
|
||||
|
||||
class RandomFlip:
|
||||
|
||||
def __init__(self, p=0.5, direction="horizontal") -> None:
|
||||
assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
|
||||
assert 0 <= p <= 1.0
|
||||
|
||||
self.p = p
|
||||
self.direction = direction
|
||||
|
||||
def __call__(self, labels):
|
||||
img = labels["img"]
|
||||
instances = labels["instances"]
|
||||
instances.convert_bbox(format="xywh")
|
||||
h, w = img.shape[:2]
|
||||
h = 1 if instances.normalized else h
|
||||
w = 1 if instances.normalized else w
|
||||
|
||||
# Flip up-down
|
||||
if self.direction == "vertical" and random.random() < self.p:
|
||||
img = np.flipud(img)
|
||||
img = np.ascontiguousarray(img)
|
||||
instances.flipud(h)
|
||||
if self.direction == "horizontal" and random.random() < self.p:
|
||||
img = np.fliplr(img)
|
||||
img = np.ascontiguousarray(img)
|
||||
instances.fliplr(w)
|
||||
labels["img"] = img
|
||||
labels["instances"] = instances
|
||||
return labels
|
||||
|
||||
|
||||
class LetterBox:
|
||||
"""Resize image and padding for detection, instance segmentation, pose"""
|
||||
|
||||
def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
|
||||
self.new_shape = new_shape
|
||||
self.auto = auto
|
||||
self.scaleFill = scaleFill
|
||||
self.scaleup = scaleup
|
||||
self.stride = stride
|
||||
|
||||
def __call__(self, labels):
|
||||
img = labels["img"]
|
||||
shape = img.shape[:2] # current shape [height, width]
|
||||
new_shape = labels.get("rect_shape", self.new_shape)
|
||||
if isinstance(new_shape, int):
|
||||
new_shape = (new_shape, new_shape)
|
||||
|
||||
# Scale ratio (new / old)
|
||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
||||
if not self.scaleup: # only scale down, do not scale up (for better val mAP)
|
||||
r = min(r, 1.0)
|
||||
|
||||
# Compute padding
|
||||
ratio = r, r # width, height ratios
|
||||
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
||||
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
||||
if self.auto: # minimum rectangle
|
||||
dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
|
||||
elif self.scaleFill: # stretch
|
||||
dw, dh = 0.0, 0.0
|
||||
new_unpad = (new_shape[1], new_shape[0])
|
||||
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
|
||||
|
||||
dw /= 2 # divide padding into 2 sides
|
||||
dh /= 2
|
||||
|
||||
if shape[::-1] != new_unpad: # resize
|
||||
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
||||
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
|
||||
value=(114, 114, 114)) # add border
|
||||
|
||||
labels = self._update_labels(labels, ratio, dw, dh)
|
||||
labels["img"] = img
|
||||
return labels
|
||||
|
||||
def _update_labels(self, labels, ratio, padw, padh):
|
||||
"""Update labels"""
|
||||
labels["instances"].convert_bbox(format="xyxy")
|
||||
labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
|
||||
labels["instances"].scale(*ratio)
|
||||
labels["instances"].add_padding(padw, padh)
|
||||
return labels
|
||||
|
||||
|
||||
class CopyPaste:
|
||||
|
||||
def __init__(self, p=0.5) -> None:
|
||||
self.p = p
|
||||
|
||||
def __call__(self, labels):
|
||||
# Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
|
||||
im = labels["img"]
|
||||
cls = labels["cls"]
|
||||
bboxes = labels["instances"].bboxes
|
||||
segments = labels["instances"].segments # n, 1000, 2
|
||||
keypoints = labels["instances"].keypoints
|
||||
if self.p and segments is not None:
|
||||
n = len(segments)
|
||||
h, w, _ = im.shape # height, width, channels
|
||||
im_new = np.zeros(im.shape, np.uint8)
|
||||
# TODO: this implement can be parallel since segments are ndarray, also might work with Instances inside
|
||||
for j in random.sample(range(n), k=round(self.p * n)):
|
||||
c, b, s = cls[j], bboxes[j], segments[j]
|
||||
box = w - b[2], b[1], w - b[0], b[3]
|
||||
ioa = bbox_ioa(box, bboxes) # intersection over area
|
||||
if (ioa < 0.30).all(): # allow 30% obscuration of existing labels
|
||||
bboxes = np.concatenate((bboxes, [box]), 0)
|
||||
cls = np.concatenate((cls, c[None]), axis=0)
|
||||
segments = np.concatenate((segments, np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)[None]), 0)
|
||||
if keypoints is not None:
|
||||
keypoints = np.concatenate(
|
||||
(keypoints, np.concatenate((w - keypoints[j][:, 0:1], keypoints[j][:, 1:2]), 1)), 0)
|
||||
cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)
|
||||
|
||||
result = cv2.bitwise_and(src1=im, src2=im_new)
|
||||
result = cv2.flip(result, 1) # augment segments (flip left-right)
|
||||
i = result > 0 # pixels to replace
|
||||
# i[:, :] = result.max(2).reshape(h, w, 1) # act over ch
|
||||
im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
|
||||
labels["img"] = im
|
||||
labels["cls"] = cls
|
||||
labels["instances"].update(bboxes, segments, keypoints)
|
||||
return labels
|
||||
|
||||
|
||||
class Albumentations:
|
||||
# YOLOv5 Albumentations class (optional, only used if package is installed)
|
||||
def __init__(self, p=1.0):
|
||||
self.p = p
|
||||
self.transform = None
|
||||
prefix = colorstr("albumentations: ")
|
||||
try:
|
||||
import albumentations as A
|
||||
|
||||
check_version(A.__version__, "1.0.3", hard=True) # version requirement
|
||||
|
||||
T = [
|
||||
A.Blur(p=0.01),
|
||||
A.MedianBlur(p=0.01),
|
||||
A.ToGray(p=0.01),
|
||||
A.CLAHE(p=0.01),
|
||||
A.RandomBrightnessContrast(p=0.0),
|
||||
A.RandomGamma(p=0.0),
|
||||
A.ImageCompression(quality_lower=75, p=0.0),] # transforms
|
||||
self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
|
||||
|
||||
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
|
||||
except ImportError: # package not installed, skip
|
||||
pass
|
||||
except Exception as e:
|
||||
LOGGER.info(f"{prefix}{e}")
|
||||
|
||||
def __call__(self, labels):
|
||||
im = labels["img"]
|
||||
cls = labels["cls"]
|
||||
if len(cls):
|
||||
labels["instances"].convert_bbox("xywh")
|
||||
labels["instances"].normalize(*im.shape[:2][::-1])
|
||||
bboxes = labels["instances"].bboxes
|
||||
# TODO: add supports of segments and keypoints
|
||||
if self.transform and random.random() < self.p:
|
||||
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
|
||||
labels["img"] = new["image"]
|
||||
labels["cls"] = np.array(new["class_labels"])
|
||||
labels["instances"].update(bboxes=bboxes)
|
||||
return labels
|
||||
|
||||
|
||||
# TODO: technically this is not an augmentation, maybe we should put this to another files
|
||||
class Format:
|
||||
|
||||
def __init__(self, bbox_format="xywh", normalize=True, mask=False, mask_ratio=4, mask_overlap=True, batch_idx=True):
|
||||
self.bbox_format = bbox_format
|
||||
self.normalize = normalize
|
||||
self.mask = mask # set False when training detection only
|
||||
self.mask_ratio = mask_ratio
|
||||
self.mask_overlap = mask_overlap
|
||||
self.batch_idx = batch_idx # keep the batch indexes
|
||||
|
||||
def __call__(self, labels):
|
||||
img = labels["img"]
|
||||
h, w = img.shape[:2]
|
||||
cls = labels.pop("cls")
|
||||
instances = labels.pop("instances")
|
||||
instances.convert_bbox(format=self.bbox_format)
|
||||
instances.denormalize(w, h)
|
||||
nl = len(instances)
|
||||
|
||||
if instances.segments is not None and self.mask:
|
||||
masks, instances, cls = self._format_segments(instances, cls, w, h)
|
||||
labels["masks"] = (torch.from_numpy(masks) if nl else torch.zeros(
|
||||
1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio))
|
||||
if self.normalize:
|
||||
instances.normalize(w, h)
|
||||
labels["img"] = self._format_img(img)
|
||||
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
|
||||
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
|
||||
if instances.keypoints is not None:
|
||||
labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
|
||||
# then we can use collate_fn
|
||||
if self.batch_idx:
|
||||
labels["batch_idx"] = torch.zeros(nl)
|
||||
return labels
|
||||
|
||||
def _format_img(self, img):
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||
img = torch.from_numpy(img)
|
||||
return img
|
||||
|
||||
def _format_segments(self, instances, cls, w, h):
|
||||
"""convert polygon points to bitmap"""
|
||||
segments = instances.segments
|
||||
if self.mask_overlap:
|
||||
masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio)
|
||||
masks = masks[None] # (640, 640) -> (1, 640, 640)
|
||||
instances = instances[sorted_idx]
|
||||
cls = cls[sorted_idx]
|
||||
else:
|
||||
masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio)
|
||||
|
||||
return masks, instances, cls
|
||||
|
||||
|
||||
def mosaic_transforms(img_size, hyp):
|
||||
pre_transform = Compose([
|
||||
Mosaic(img_size=img_size, p=hyp.mosaic, border=[-img_size // 2, -img_size // 2]),
|
||||
CopyPaste(p=hyp.copy_paste),
|
||||
RandomPerspective(
|
||||
degrees=hyp.degrees,
|
||||
translate=hyp.translate,
|
||||
scale=hyp.scale,
|
||||
shear=hyp.shear,
|
||||
perspective=hyp.perspective,
|
||||
border=[-img_size // 2, -img_size // 2],
|
||||
),])
|
||||
transforms = Compose([
|
||||
pre_transform,
|
||||
MixUp(
|
||||
pre_transform=pre_transform,
|
||||
p=hyp.mixup,
|
||||
),
|
||||
Albumentations(p=1.0),
|
||||
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
||||
RandomFlip(direction="vertical", p=hyp.flipud),
|
||||
RandomFlip(direction="horizontal", p=hyp.fliplr),])
|
||||
return transforms
|
||||
|
||||
|
||||
def affine_transforms(img_size, hyp):
|
||||
# rect, randomperspective, albumentation, hsv, flipud, fliplr
|
||||
transforms = Compose([
|
||||
LetterBox(new_shape=(img_size, img_size)),
|
||||
RandomPerspective(
|
||||
degrees=hyp.degrees,
|
||||
translate=hyp.translate,
|
||||
scale=hyp.scale,
|
||||
shear=hyp.shear,
|
||||
perspective=hyp.perspective,
|
||||
border=[0, 0],
|
||||
),
|
||||
Albumentations(p=1.0),
|
||||
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
||||
RandomFlip(direction="vertical", p=hyp.flipud),
|
||||
RandomFlip(direction="horizontal", p=hyp.fliplr),])
|
||||
return transforms
|
||||
|
||||
|
||||
# Classification augmentations -------------------------------------------------------------------------------------------
|
||||
def classify_transforms(size=224):
|
||||
# Transforms to apply if albumentations not installed
|
||||
assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"
|
||||
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||
|
||||
|
||||
def classify_albumentations(
|
||||
augment=True,
|
||||
size=224,
|
||||
scale=(0.08, 1.0),
|
||||
hflip=0.5,
|
||||
vflip=0.0,
|
||||
jitter=0.4,
|
||||
mean=IMAGENET_MEAN,
|
||||
std=IMAGENET_STD,
|
||||
auto_aug=False,
|
||||
):
|
||||
# YOLOv5 classification Albumentations (optional, only used if package is installed)
|
||||
prefix = colorstr("albumentations: ")
|
||||
try:
|
||||
import albumentations as A
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
|
||||
check_version(A.__version__, "1.0.3", hard=True) # version requirement
|
||||
if augment: # Resize and crop
|
||||
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
|
||||
if auto_aug:
|
||||
# TODO: implement AugMix, AutoAug & RandAug in albumentation
|
||||
LOGGER.info(f"{prefix}auto augmentations are currently not supported")
|
||||
else:
|
||||
if hflip > 0:
|
||||
T += [A.HorizontalFlip(p=hflip)]
|
||||
if vflip > 0:
|
||||
T += [A.VerticalFlip(p=vflip)]
|
||||
if jitter > 0:
|
||||
color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, satuaration, 0 hue
|
||||
T += [A.ColorJitter(*color_jitter, 0)]
|
||||
else: # Use fixed crop for eval set (reproducibility)
|
||||
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
|
||||
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
|
||||
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
|
||||
return A.Compose(T)
|
||||
|
||||
except ImportError: # package not installed, skip
|
||||
pass
|
||||
except Exception as e:
|
||||
LOGGER.info(f"{prefix}{e}")
|
||||
|
||||
|
||||
class ClassifyLetterBox:
|
||||
# YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
||||
def __init__(self, size=(640, 640), auto=False, stride=32):
|
||||
super().__init__()
|
||||
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||
self.auto = auto # pass max size integer, automatically solve for short side using stride
|
||||
self.stride = stride # used with auto
|
||||
|
||||
def __call__(self, im): # im = np.array HWC
|
||||
imh, imw = im.shape[:2]
|
||||
r = min(self.h / imh, self.w / imw) # ratio of new/old
|
||||
h, w = round(imh * r), round(imw * r) # resized image
|
||||
hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w
|
||||
top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
|
||||
im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
|
||||
im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
|
||||
return im_out
|
||||
|
||||
|
||||
class CenterCrop:
|
||||
# YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
|
||||
def __init__(self, size=640):
|
||||
super().__init__()
|
||||
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||
|
||||
def __call__(self, im): # im = np.array HWC
|
||||
imh, imw = im.shape[:2]
|
||||
m = min(imh, imw) # min dimension
|
||||
top, left = (imh - m) // 2, (imw - m) // 2
|
||||
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
|
||||
class ToTensor:
|
||||
# YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
||||
def __init__(self, half=False):
|
||||
super().__init__()
|
||||
self.half = half
|
||||
|
||||
def __call__(self, im): # im = np.array HWC in BGR order
|
||||
im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
|
||||
im = torch.from_numpy(im) # to torch
|
||||
im = im.half() if self.half else im.float() # uint8 to fp16/32
|
||||
im /= 255.0 # 0-255 to 0.0-1.0
|
||||
return im
|
224
ultralytics/yolo/data/base.py
Normal file
224
ultralytics/yolo/data/base.py
Normal file
@ -0,0 +1,224 @@
|
||||
import glob
|
||||
import os
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils.general import NUM_THREADS
|
||||
from .utils import BAR_FORMAT, HELP_URL, IMG_FORMATS, LOCAL_RANK
|
||||
|
||||
|
||||
class BaseDataset(Dataset):
|
||||
"""Base Dataset.
|
||||
Args:
|
||||
img_path (str): image path.
|
||||
pipeline (dict): a dict of image transforms.
|
||||
label_path (str): label path, this can also be a ann_file or other custom label path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_path,
|
||||
img_size=640,
|
||||
label_path=None,
|
||||
cache=False,
|
||||
augment=True,
|
||||
hyp=None,
|
||||
prefix="",
|
||||
rect=False,
|
||||
batch_size=None,
|
||||
stride=32,
|
||||
pad=0.5,
|
||||
single_cls=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.img_path = img_path
|
||||
self.img_size = img_size
|
||||
self.label_path = label_path
|
||||
self.augment = augment
|
||||
self.prefix = prefix
|
||||
|
||||
self.im_files = self.get_img_files(self.img_path)
|
||||
self.labels = self.get_labels()
|
||||
if single_cls:
|
||||
self.update_labels(include_class=[], single_cls=single_cls)
|
||||
|
||||
self.ni = len(self.im_files)
|
||||
|
||||
# rect stuff
|
||||
self.rect = rect
|
||||
self.batch_size = batch_size
|
||||
self.stride = stride
|
||||
self.pad = pad
|
||||
if self.rect:
|
||||
assert self.batch_size is not None
|
||||
self.set_rectangle()
|
||||
|
||||
# cache stuff
|
||||
self.ims = [None] * self.ni
|
||||
self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
|
||||
if cache:
|
||||
self.cache_images()
|
||||
|
||||
# transforms
|
||||
self.transforms = self.build_transforms(hyp=hyp)
|
||||
|
||||
def get_img_files(self, img_path):
|
||||
"""Read image files."""
|
||||
try:
|
||||
f = [] # image files
|
||||
for p in img_path if isinstance(img_path, list) else [img_path]:
|
||||
p = Path(p) # os-agnostic
|
||||
if p.is_dir(): # dir
|
||||
f += glob.glob(str(p / "**" / "*.*"), recursive=True)
|
||||
# f = list(p.rglob('*.*')) # pathlib
|
||||
elif p.is_file(): # file
|
||||
with open(p) as t:
|
||||
t = t.read().strip().splitlines()
|
||||
parent = str(p.parent) + os.sep
|
||||
f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path
|
||||
# f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
|
||||
else:
|
||||
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
|
||||
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
|
||||
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
||||
assert im_files, f"{self.prefix}No images found"
|
||||
except Exception as e:
|
||||
raise Exception(f"{self.prefix}Error loading data from {img_path}: {e}\n{HELP_URL}")
|
||||
return im_files
|
||||
|
||||
def update_labels(self, include_class: Optional[list]):
|
||||
"""include_class, filter labels to include only these classes (optional)"""
|
||||
include_class_array = np.array(include_class).reshape(1, -1)
|
||||
for i in range(len(self.labels)):
|
||||
if include_class:
|
||||
cls = self.labels[i]["cls"]
|
||||
bboxes = self.labels[i]["bboxes"]
|
||||
segments = self.labels[i]["segments"]
|
||||
j = (cls == include_class_array).any(1)
|
||||
self.labels[i]["cls"] = cls[j]
|
||||
self.labels[i]["bboxes"] = bboxes[j]
|
||||
if segments:
|
||||
self.labels[i]["segments"] = segments[j]
|
||||
if self.single_cls:
|
||||
self.labels[i]["cls"] = 0
|
||||
|
||||
def load_image(self, i):
|
||||
# Loads 1 image from dataset index 'i', returns (im, resized hw)
|
||||
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
|
||||
if im is None: # not cached in RAM
|
||||
if fn.exists(): # load npy
|
||||
im = np.load(fn)
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
assert im is not None, f"Image Not Found {f}"
|
||||
h0, w0 = im.shape[:2] # orig hw
|
||||
r = self.img_size / max(h0, w0) # ratio
|
||||
if r != 1: # if sizes are not equal
|
||||
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
|
||||
im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
|
||||
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
|
||||
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
|
||||
|
||||
def cache_images(self):
|
||||
# cache images to memory or disk
|
||||
gb = 0 # Gigabytes of cached images
|
||||
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
|
||||
fcn = self.cache_images_to_disk if self.cache == "disk" else self.load_image
|
||||
results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni))
|
||||
pbar = tqdm(enumerate(results), total=self.ni, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if self.cache == "disk":
|
||||
gb += self.npy_files[i].stat().st_size
|
||||
else: # 'ram'
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||
gb += self.ims[i].nbytes
|
||||
pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {self.cache})"
|
||||
pbar.close()
|
||||
|
||||
def cache_images_to_disk(self, i):
|
||||
# Saves an image as an *.npy file for faster loading
|
||||
f = self.npy_files[i]
|
||||
if not f.exists():
|
||||
np.save(f.as_posix(), cv2.imread(self.im_files[i]))
|
||||
|
||||
def set_rectangle(self):
|
||||
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
|
||||
nb = bi[-1] + 1 # number of batches
|
||||
|
||||
s = np.array([x["shape"] for x in self.labels]) # hw
|
||||
ar = s[:, 0] / s[:, 1] # aspect ratio
|
||||
irect = ar.argsort()
|
||||
self.im_files = [self.im_files[i] for i in irect]
|
||||
self.labels = [self.labels[i] for i in irect]
|
||||
ar = ar[irect]
|
||||
|
||||
# Set training image shapes
|
||||
shapes = [[1, 1]] * nb
|
||||
for i in range(nb):
|
||||
ari = ar[bi == i]
|
||||
mini, maxi = ari.min(), ari.max()
|
||||
if maxi < 1:
|
||||
shapes[i] = [maxi, 1]
|
||||
elif mini > 1:
|
||||
shapes[i] = [1, 1 / mini]
|
||||
|
||||
self.batch_shapes = np.ceil(np.array(shapes) * self.img_size / self.stride + self.pad).astype(int) * self.stride
|
||||
self.batch = bi # batch index of image
|
||||
|
||||
def __getitem__(self, index):
|
||||
label = self.get_label_info(index)
|
||||
if self.augment:
|
||||
label["dataset"] = self
|
||||
return self.transforms(label)
|
||||
|
||||
def get_label_info(self, index):
|
||||
label = self.labels[index].copy()
|
||||
img, (h0, w0), (h, w) = self.load_image(index)
|
||||
label["img"] = img
|
||||
label["ori_shape"] = (h0, w0)
|
||||
label["resized_shape"] = (h, w)
|
||||
if self.rect:
|
||||
label["rect_shape"] = self.batch_shapes[self.batch[index]]
|
||||
label = self.update_labels_info(label)
|
||||
return label
|
||||
|
||||
def __len__(self):
|
||||
return len(self.im_files)
|
||||
|
||||
def update_labels_info(self, label):
|
||||
"""custom your label format here"""
|
||||
return label
|
||||
|
||||
def build_transforms(self, hyp=None):
|
||||
"""Users can custom augmentations here
|
||||
like:
|
||||
if self.augment:
|
||||
# training transforms
|
||||
return Compose([])
|
||||
else:
|
||||
# val transforms
|
||||
return Compose([])
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_labels(self):
|
||||
"""Users can custom their own format here.
|
||||
Make sure your output is a list with each element like below:
|
||||
dict(
|
||||
im_file=im_file,
|
||||
shape=shape, # format: (height, width)
|
||||
cls=cls,
|
||||
bboxes=bboxes, # xywh
|
||||
segments=segments, # xy
|
||||
keypoints=keypoints, # xy
|
||||
normalized=True, # or False
|
||||
bbox_format="xyxy", # or xywh, ltwh
|
||||
)
|
||||
"""
|
||||
raise NotImplementedError
|
145
ultralytics/yolo/data/build.py
Normal file
145
ultralytics/yolo/data/build.py
Normal file
@ -0,0 +1,145 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, dataloader, distributed
|
||||
|
||||
from ..utils.general import LOGGER
|
||||
from ..utils.torch_utils import torch_distributed_zero_first
|
||||
from .dataset import ClassificationDataset, YOLODataset
|
||||
from .utils import PIN_MEMORY, RANK
|
||||
|
||||
|
||||
class InfiniteDataLoader(dataloader.DataLoader):
|
||||
"""Dataloader that reuses workers
|
||||
|
||||
Uses same syntax as vanilla DataLoader
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
|
||||
self.iterator = super().__iter__()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.batch_sampler.sampler)
|
||||
|
||||
def __iter__(self):
|
||||
for _ in range(len(self)):
|
||||
yield next(self.iterator)
|
||||
|
||||
|
||||
class _RepeatSampler:
|
||||
"""Sampler that repeats forever
|
||||
|
||||
Args:
|
||||
sampler (Sampler)
|
||||
"""
|
||||
|
||||
def __init__(self, sampler):
|
||||
self.sampler = sampler
|
||||
|
||||
def __iter__(self):
|
||||
while True:
|
||||
yield from iter(self.sampler)
|
||||
|
||||
|
||||
def seed_worker(worker_id):
|
||||
# Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
|
||||
worker_seed = torch.initial_seed() % 2 ** 32
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
|
||||
# TODO: we can inject most args from a config file
|
||||
def build_dataloader(
|
||||
img_path,
|
||||
img_size, #
|
||||
batch_size, #
|
||||
single_cls=False, #
|
||||
hyp=None, #
|
||||
augment=False,
|
||||
cache=False, #
|
||||
image_weights=False, #
|
||||
stride=32,
|
||||
label_path=None,
|
||||
pad=0.0,
|
||||
rect=False,
|
||||
rank=-1,
|
||||
workers=8,
|
||||
prefix="",
|
||||
shuffle=False,
|
||||
use_segments=False,
|
||||
use_keypoints=False,
|
||||
):
|
||||
if rect and shuffle:
|
||||
LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||
shuffle = False
|
||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = YOLODataset(
|
||||
img_path=img_path,
|
||||
img_size=img_size,
|
||||
batch_size=batch_size,
|
||||
label_path=label_path,
|
||||
augment=augment, # augmentation
|
||||
hyp=hyp,
|
||||
rect=rect, # rectangular batches
|
||||
cache=cache,
|
||||
single_cls=single_cls,
|
||||
stride=int(stride),
|
||||
pad=pad,
|
||||
prefix=prefix,
|
||||
use_segments=use_segments,
|
||||
use_keypoints=use_keypoints,
|
||||
)
|
||||
|
||||
batch_size = min(batch_size, len(dataset))
|
||||
nd = torch.cuda.device_count() # number of CUDA devices
|
||||
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(6148914691236517205 + RANK)
|
||||
return (
|
||||
loader(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle and sampler is None,
|
||||
num_workers=nw,
|
||||
sampler=sampler,
|
||||
pin_memory=PIN_MEMORY,
|
||||
collate_fn=getattr(dataset, "collate_fn", None),
|
||||
worker_init_fn=seed_worker,
|
||||
generator=generator,
|
||||
),
|
||||
dataset,
|
||||
)
|
||||
|
||||
|
||||
# build classification
|
||||
def build_classification_dataloader(path,
|
||||
imgsz=224,
|
||||
batch_size=16,
|
||||
augment=True,
|
||||
cache=False,
|
||||
rank=-1,
|
||||
workers=8,
|
||||
shuffle=True):
|
||||
# Returns Dataloader object to be used with YOLOv5 Classifier
|
||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
|
||||
batch_size = min(batch_size, len(dataset))
|
||||
nd = torch.cuda.device_count()
|
||||
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
|
||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(6148914691236517205 + RANK)
|
||||
return InfiniteDataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle and sampler is None,
|
||||
num_workers=nw,
|
||||
sampler=sampler,
|
||||
pin_memory=PIN_MEMORY,
|
||||
worker_init_fn=seed_worker,
|
||||
generator=generator) # or DataLoader(persistent_workers=True)
|
0
ultralytics/yolo/data/dataloaders/__init__.py
Normal file
0
ultralytics/yolo/data/dataloaders/__init__.py
Normal file
0
ultralytics/yolo/data/dataloaders/box.py
Normal file
0
ultralytics/yolo/data/dataloaders/box.py
Normal file
0
ultralytics/yolo/data/dataloaders/segment.py
Normal file
0
ultralytics/yolo/data/dataloaders/segment.py
Normal file
213
ultralytics/yolo/data/dataset.py
Normal file
213
ultralytics/yolo/data/dataset.py
Normal file
@ -0,0 +1,213 @@
|
||||
from itertools import repeat
|
||||
from multiprocessing.pool import Pool
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils.general import LOGGER, NUM_THREADS
|
||||
from .augment import *
|
||||
from .base import BaseDataset
|
||||
from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
|
||||
|
||||
|
||||
class YOLODataset(BaseDataset):
|
||||
cache_version = 0.6 # dataset labels *.cache version
|
||||
rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
|
||||
"""YOLO Dataset.
|
||||
Args:
|
||||
img_path (str): image path.
|
||||
prefix (str): prefix.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_path,
|
||||
img_size=640,
|
||||
label_path=None,
|
||||
cache=False,
|
||||
augment=True,
|
||||
hyp=None,
|
||||
prefix="",
|
||||
rect=False,
|
||||
batch_size=None,
|
||||
stride=32,
|
||||
pad=0.0,
|
||||
single_cls=False,
|
||||
use_segments=False,
|
||||
use_keypoints=False,
|
||||
):
|
||||
self.use_segments = use_segments
|
||||
self.use_keypoints = use_keypoints
|
||||
assert not (self.use_segments and self.use_keypoints), "We can't use both of segmentation and pose."
|
||||
super().__init__(img_path, img_size, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
|
||||
single_cls)
|
||||
|
||||
def cache_labels(self, path=Path("./labels.cache")):
|
||||
# Cache dataset labels, check images and read shapes
|
||||
x = {"labels": []}
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f"{self.prefix}Scanning '{path.parent / path.stem}' images and labels..."
|
||||
with Pool(NUM_THREADS) as pool:
|
||||
pbar = tqdm(
|
||||
pool.imap(verify_image_label,
|
||||
zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
|
||||
desc=desc,
|
||||
total=len(self.im_files),
|
||||
bar_format=BAR_FORMAT,
|
||||
)
|
||||
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
||||
nm += nm_f
|
||||
nf += nf_f
|
||||
ne += ne_f
|
||||
nc += nc_f
|
||||
if im_file:
|
||||
x["labels"].append(
|
||||
dict(
|
||||
im_file=im_file,
|
||||
shape=shape,
|
||||
cls=lb[:, 0:1], # n, 1
|
||||
bboxes=lb[:, 1:], # n, 4
|
||||
segments=segments,
|
||||
keypoints=keypoint,
|
||||
normalized=True,
|
||||
bbox_format="xywh",
|
||||
))
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
|
||||
|
||||
pbar.close()
|
||||
if msgs:
|
||||
LOGGER.info("\n".join(msgs))
|
||||
if nf == 0:
|
||||
LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
|
||||
x["hash"] = get_hash(self.label_files + self.im_files)
|
||||
x["results"] = nf, nm, ne, nc, len(self.im_files)
|
||||
x["msgs"] = msgs # warnings
|
||||
x["version"] = self.cache_version # cache version
|
||||
try:
|
||||
np.save(path, x) # save cache for next time
|
||||
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
||||
LOGGER.info(f"{self.prefix}New cache created: {path}")
|
||||
except Exception as e:
|
||||
LOGGER.warning(
|
||||
f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}") # not writeable
|
||||
return x
|
||||
|
||||
def get_labels(self):
|
||||
self.label_files = img2label_paths(self.im_files)
|
||||
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
|
||||
try:
|
||||
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
|
||||
assert cache["version"] == self.cache_version # matches current version
|
||||
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
|
||||
except Exception:
|
||||
cache, exists = self.cache_labels(cache_path), False # run cache ops
|
||||
|
||||
# Display cache
|
||||
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
|
||||
if exists and LOCAL_RANK in {-1, 0}:
|
||||
d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
|
||||
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
|
||||
if cache["msgs"]:
|
||||
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
||||
assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"
|
||||
|
||||
# Read cache
|
||||
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
|
||||
labels = cache["labels"]
|
||||
nl = len(np.concatenate([label["cls"] for label in labels], 0)) # number of labels
|
||||
assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}"
|
||||
return labels
|
||||
|
||||
# TODO: use hyp config to set all these augmentations
|
||||
def build_transforms(self, hyp=None):
|
||||
mosaic = self.augment and not self.rect
|
||||
# mosaic = False
|
||||
if self.augment:
|
||||
if mosaic:
|
||||
transforms = mosaic_transforms(self.img_size, hyp)
|
||||
else:
|
||||
transforms = affine_transforms(self.img_size, hyp)
|
||||
else:
|
||||
transforms = Compose([LetterBox(new_shape=(self.img_size, self.img_size))])
|
||||
transforms.append(Format(bbox_format="xywh", normalize=True, mask=self.use_segments, batch_idx=True))
|
||||
return transforms
|
||||
|
||||
def update_labels_info(self, label):
|
||||
"""custom your label format here"""
|
||||
# NOTE: cls is not with bboxes now, since other tasks like classification and semantic segmentation need a independent cls label
|
||||
# we can make it also support classification and semantic segmentation by add or remove some dict keys there.
|
||||
bboxes = label.pop("bboxes")
|
||||
segments = label.pop("segments", None)
|
||||
keypoints = label.pop("keypoints", None)
|
||||
bbox_format = label.pop("bbox_format")
|
||||
normalized = label.pop("normalized")
|
||||
label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
|
||||
return label
|
||||
|
||||
@staticmethod
|
||||
def collate_fn(batch):
|
||||
# TODO: returning a dict can make thing easier and cleaner when using dataset in training
|
||||
# but I don't know if this will slow down a little bit.
|
||||
new_batch = {}
|
||||
keys = batch[0].keys()
|
||||
values = list(zip(*[list(b.values()) for b in batch]))
|
||||
for i, k in enumerate(keys):
|
||||
value = values[i]
|
||||
if k == "img":
|
||||
value = torch.stack(value, 0)
|
||||
if k in ["mask", "keypoint", "bboxes", "cls"]:
|
||||
value = torch.cat(value, 0)
|
||||
new_batch[k] = values[i]
|
||||
new_batch["batch_idx"] = list(new_batch["batch_idx"])
|
||||
for i in range(len(new_batch["batch_idx"])):
|
||||
new_batch["batch_idx"][i] += i # add target image index for build_targets()
|
||||
new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
|
||||
return new_batch
|
||||
|
||||
|
||||
# Classification dataloaders -------------------------------------------------------------------------------------------
|
||||
class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||
"""
|
||||
YOLOv5 Classification Dataset.
|
||||
Arguments
|
||||
root: Dataset path
|
||||
transform: torchvision transforms, used by default
|
||||
album_transform: Albumentations transforms, used if installed
|
||||
"""
|
||||
|
||||
def __init__(self, root, augment, imgsz, cache=False):
|
||||
super().__init__(root=root)
|
||||
self.torch_transforms = classify_transforms(imgsz)
|
||||
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
|
||||
self.cache_ram = cache is True or cache == "ram"
|
||||
self.cache_disk = cache == "disk"
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
|
||||
|
||||
def __getitem__(self, i):
|
||||
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
||||
if self.cache_ram and im is None:
|
||||
im = self.samples[i][3] = cv2.imread(f)
|
||||
elif self.cache_disk:
|
||||
if not fn.exists(): # load npy
|
||||
np.save(fn.as_posix(), cv2.imread(f))
|
||||
im = np.load(fn)
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
if self.album_transforms:
|
||||
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
|
||||
else:
|
||||
sample = self.torch_transforms(im)
|
||||
return sample, j
|
||||
|
||||
|
||||
# TODO: support semantic segmentation
|
||||
class SemanticDataset(BaseDataset):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
37
ultralytics/yolo/data/dataset_wrappers.py
Normal file
37
ultralytics/yolo/data/dataset_wrappers.py
Normal file
@ -0,0 +1,37 @@
|
||||
import collections
|
||||
from copy import deepcopy
|
||||
|
||||
from .augment import LetterBox
|
||||
|
||||
|
||||
class MixAndRectDataset:
|
||||
"""A wrapper of multiple images mixed dataset.
|
||||
|
||||
Args:
|
||||
dataset (:obj:`BaseDataset`): The dataset to be mixed.
|
||||
transforms (Sequence[dict]): config dict to be composed.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset):
|
||||
self.dataset = dataset
|
||||
self.img_size = dataset.img_size
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, index):
|
||||
labels = deepcopy(self.dataset[index])
|
||||
for transform in self.dataset.transforms.tolist():
|
||||
# mosaic and mixup
|
||||
if hasattr(transform, "get_indexes"):
|
||||
indexes = transform.get_indexes(self.dataset)
|
||||
if not isinstance(indexes, collections.abc.Sequence):
|
||||
indexes = [indexes]
|
||||
mix_labels = [deepcopy(self.dataset[index]) for index in indexes]
|
||||
labels["mix_labels"] = mix_labels
|
||||
if self.dataset.rect and isinstance(transform, LetterBox):
|
||||
transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
|
||||
labels = transform(labels)
|
||||
if "mix_labels" in labels:
|
||||
labels.pop("mix_labels")
|
||||
return labels
|
177
ultralytics/yolo/data/utils.py
Normal file
177
ultralytics/yolo/data/utils.py
Normal file
@ -0,0 +1,177 @@
|
||||
import contextlib
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import ExifTags, Image, ImageOps
|
||||
|
||||
from ..utils.general import segments2boxes
|
||||
|
||||
HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
|
||||
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes
|
||||
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv" # include video suffixes
|
||||
BAR_FORMAT = "{l_bar}{bar:10}{r_bar}{bar:-10b}" # tqdm bar format
|
||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
|
||||
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
|
||||
IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
|
||||
|
||||
# Get orientation exif tag
|
||||
for orientation in ExifTags.TAGS.keys():
|
||||
if ExifTags.TAGS[orientation] == "Orientation":
|
||||
break
|
||||
|
||||
|
||||
def img2label_paths(img_paths):
|
||||
# Define label paths as a function of image paths
|
||||
sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings
|
||||
return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
|
||||
|
||||
|
||||
def get_hash(paths):
|
||||
# Returns a single hash value of a list of paths (files or dirs)
|
||||
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
|
||||
h = hashlib.md5(str(size).encode()) # hash sizes
|
||||
h.update("".join(paths).encode()) # hash paths
|
||||
return h.hexdigest() # return hash
|
||||
|
||||
|
||||
def exif_size(img):
|
||||
# Returns exif-corrected PIL size
|
||||
s = img.size # (width, height)
|
||||
with contextlib.suppress(Exception):
|
||||
rotation = dict(img._getexif().items())[orientation]
|
||||
if rotation in [6, 8]: # rotation 270 or 90
|
||||
s = (s[1], s[0])
|
||||
return s
|
||||
|
||||
|
||||
def verify_image_label(args):
|
||||
# Verify one image-label pair
|
||||
im_file, lb_file, prefix, keypoint = args
|
||||
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None # number (missing, found, empty, corrupt), message, segments, keypoints
|
||||
try:
|
||||
# verify images
|
||||
im = Image.open(im_file)
|
||||
im.verify() # PIL verify
|
||||
shape = exif_size(im) # image size
|
||||
shape = (shape[1], shape[0]) # hw
|
||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
||||
if im.format.lower() in ("jpg", "jpeg"):
|
||||
with open(im_file, "rb") as f:
|
||||
f.seek(-2, 2)
|
||||
if f.read() != b"\xff\xd9": # corrupt JPEG
|
||||
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
|
||||
msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
|
||||
|
||||
# verify labels
|
||||
if os.path.isfile(lb_file):
|
||||
nf = 1 # label found
|
||||
with open(lb_file) as f:
|
||||
lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
|
||||
if any(len(x) > 6 for x in lb) and (not keypoint): # is segment
|
||||
classes = np.array([x[0] for x in lb], dtype=np.float32)
|
||||
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
|
||||
lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
|
||||
lb = np.array(lb, dtype=np.float32)
|
||||
nl = len(lb)
|
||||
if nl:
|
||||
if keypoint:
|
||||
assert lb.shape[1] == 56, "labels require 56 columns each"
|
||||
assert (lb[:, 5::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
|
||||
assert (lb[:, 6::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
|
||||
kpts = np.zeros((lb.shape[0], 39))
|
||||
for i in range(len(lb)):
|
||||
kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5,
|
||||
3)) # remove the occlusion paramater from the GT
|
||||
kpts[i] = np.hstack((lb[i, :5], kpt))
|
||||
lb = kpts
|
||||
assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion paramater"
|
||||
else:
|
||||
assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
|
||||
assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
|
||||
assert (lb[:, 1:] <=
|
||||
1).all(), f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
|
||||
_, i = np.unique(lb, axis=0, return_index=True)
|
||||
if len(i) < nl: # duplicate row check
|
||||
lb = lb[i] # remove duplicates
|
||||
if segments:
|
||||
segments = [segments[x] for x in i]
|
||||
msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
|
||||
else:
|
||||
ne = 1 # label empty
|
||||
lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
|
||||
else:
|
||||
nm = 1 # label missing
|
||||
lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
|
||||
if keypoint:
|
||||
keypoints = lb[:, 5:].reshape(-1, 17, 2)
|
||||
lb = lb[:, :5]
|
||||
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
|
||||
except Exception as e:
|
||||
nc = 1
|
||||
msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
|
||||
return [None, None, None, None, None, nm, nf, ne, nc, msg]
|
||||
|
||||
|
||||
def polygon2mask(img_size, polygons, color=1, downsample_ratio=1):
|
||||
"""
|
||||
Args:
|
||||
img_size (tuple): The image size.
|
||||
polygons (np.ndarray): [N, M], N is the number of polygons,
|
||||
M is the number of points(Be divided by 2).
|
||||
"""
|
||||
mask = np.zeros(img_size, dtype=np.uint8)
|
||||
polygons = np.asarray(polygons)
|
||||
polygons = polygons.astype(np.int32)
|
||||
shape = polygons.shape
|
||||
polygons = polygons.reshape(shape[0], -1, 2)
|
||||
cv2.fillPoly(mask, polygons, color=color)
|
||||
nh, nw = (img_size[0] // downsample_ratio, img_size[1] // downsample_ratio)
|
||||
# NOTE: fillPoly firstly then resize is trying the keep the same way
|
||||
# of loss calculation when mask-ratio=1.
|
||||
mask = cv2.resize(mask, (nw, nh))
|
||||
return mask
|
||||
|
||||
|
||||
def polygons2masks(img_size, polygons, color, downsample_ratio=1):
|
||||
"""
|
||||
Args:
|
||||
img_size (tuple): The image size.
|
||||
polygons (list[np.ndarray]): each polygon is [N, M],
|
||||
N is the number of polygons,
|
||||
M is the number of points(Be divided by 2).
|
||||
"""
|
||||
masks = []
|
||||
for si in range(len(polygons)):
|
||||
mask = polygon2mask(img_size, [polygons[si].reshape(-1)], color, downsample_ratio)
|
||||
masks.append(mask)
|
||||
return np.array(masks)
|
||||
|
||||
|
||||
def polygons2masks_overlap(img_size, segments, downsample_ratio=1):
|
||||
"""Return a (640, 640) overlap mask."""
|
||||
masks = np.zeros((img_size[0] // downsample_ratio, img_size[1] // downsample_ratio),
|
||||
dtype=np.int32 if len(segments) > 255 else np.uint8)
|
||||
areas = []
|
||||
ms = []
|
||||
for si in range(len(segments)):
|
||||
mask = polygon2mask(
|
||||
img_size,
|
||||
[segments[si].reshape(-1)],
|
||||
downsample_ratio=downsample_ratio,
|
||||
color=1,
|
||||
)
|
||||
ms.append(mask)
|
||||
areas.append(mask.sum())
|
||||
areas = np.asarray(areas)
|
||||
index = np.argsort(-areas)
|
||||
ms = np.array(ms)[index]
|
||||
for i in range(len(segments)):
|
||||
mask = ms[i] * (i + 1)
|
||||
masks = masks + mask
|
||||
masks = np.clip(masks, a_min=0, a_max=i + 1)
|
||||
return masks, index
|
Reference in New Issue
Block a user