Trainer + Dataloaders (#27)

Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ayush Chaurasia <ayushchaurasia@Ayushs-MacBook-Pro.local>
Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com>
Ayush Chaurasia
2022-10-10 17:31:07 +05:30
committed by GitHub
parent 7a2e5fdfa3
commit d0b3c9812b
27 changed files with 2885 additions and 9 deletions

View File

@@ -0,0 +1,3 @@
from .build import build_classification_dataloader, build_dataloader
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
from .dataset_wrappers import MixAndRectDataset

View File

@@ -0,0 +1,785 @@
import collections
import math
import random
from copy import deepcopy
import cv2
import numpy as np
import torch
import torchvision.transforms as T
from ..utils.general import LOGGER, check_version, colorstr, segment2box
from ..utils.instance import Instances
from ..utils.metrics import bbox_ioa
from .utils import IMAGENET_MEAN, IMAGENET_STD, polygons2masks, polygons2masks_overlap
# TODO: we might need a BaseTransform to make all these augmentations compatible with both classification and semantic segmentation
class BaseTransform:
def __init__(self) -> None:
pass
def apply_image(self, labels):
pass
def apply_instances(self, labels):
pass
def apply_semantic(self, labels):
pass
def __call__(self, labels):
self.apply_image(labels)
self.apply_instances(labels)
self.apply_semantic(labels)
class Compose:
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, data):
for t in self.transforms:
data = t(data)
return data
def append(self, transform):
self.transforms.append(transform)
def tolist(self):
return self.transforms
def __repr__(self):
format_string = f"{self.__class__.__name__}("
for t in self.transforms:
format_string += "\n"
format_string += f" {t}"
format_string += "\n)"
return format_string
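# A minimal sketch of composing transforms (illustrative; the lambda
# transforms below are hypothetical stand-ins for the classes in this file).
# Every transform takes and returns the labels dict, so they chain freely:
#   add_one = lambda labels: {**labels, "x": labels["x"] + 1}
#   double = lambda labels: {**labels, "x": labels["x"] * 2}
#   pipeline = Compose([add_one, double])
#   assert pipeline({"x": 1})["x"] == 4  # (1 + 1) * 2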
class BaseMixTransform:
"""This implementation is from mmyolo"""
def __init__(self, pre_transform=None, p=0.0) -> None:
self.pre_transform = pre_transform
self.p = p
def __call__(self, labels):
if random.uniform(0, 1) > self.p:
return labels
assert "dataset" in labels
dataset = labels.pop("dataset")
# get index of one or three other images
indexes = self.get_indexes(dataset)
if not isinstance(indexes, collections.abc.Sequence):
indexes = [indexes]
# get the image information that will be used for Mosaic or MixUp
mix_labels = [deepcopy(dataset.get_label_info(index)) for index in indexes]
if self.pre_transform is not None:
for i, data in enumerate(mix_labels):
# pre_transform may also require dataset
data.update({"dataset": dataset})
# images used in Mosaic or MixUp must first go through
# the necessary pre_transform
_labels = self.pre_transform(data)
_labels.pop("dataset")
mix_labels[i] = _labels
labels["mix_labels"] = mix_labels
# Mosaic or MixUp
labels = self._mix_transform(labels)
if "mix_labels" in labels:
labels.pop("mix_labels")
labels["dataset"] = dataset
return labels
def _mix_transform(self, labels):
raise NotImplementedError
def get_indexes(self, dataset):
raise NotImplementedError
class Mosaic(BaseMixTransform):
"""Mosaic augmentation.
Args:
img_size (int): Image size (single dimension) of one image after the mosaic pipeline. Defaults to 640.
"""
def __init__(self, img_size=640, p=1.0, border=(0, 0)):
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
super().__init__(pre_transform=None, p=p)
self.img_size = img_size
self.border = border
def get_indexes(self, dataset):
return [random.randint(0, len(dataset) - 1) for _ in range(3)]
def _mix_transform(self, labels):
mosaic_labels = []
assert labels.get("rect_shape", None) is None, "rect and mosaic is exclusive."
assert len(labels.get("mix_labels", [])) > 0, "There are no other images for mosaic augment."
s = self.img_size
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
mix_labels = labels["mix_labels"]
for i in range(4):
labels_patch = deepcopy(labels) if i == 0 else deepcopy(mix_labels[i - 1])
# Load image
img = labels_patch["img"]
h, w = labels_patch["resized_shape"]
# place img in img4
if i == 0: # top left
img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
elif i == 1: # top right
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
elif i == 2: # bottom left
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
elif i == 3: # bottom right
x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
padw = x1a - x1b
padh = y1a - y1b
labels_patch = self._update_labels(labels_patch, padw, padh)
mosaic_labels.append(labels_patch)
final_labels = self._cat_labels(mosaic_labels)
final_labels["img"] = img4
return final_labels
def _update_labels(self, labels, padw, padh):
"""Update labels"""
nh, nw = labels["img"].shape[:2]
labels["instances"].convert_bbox(format="xyxy")
labels["instances"].denormalize(nw, nh)
labels["instances"].add_padding(padw, padh)
return labels
def _cat_labels(self, mosaic_labels):
if len(mosaic_labels) == 0:
return {}
cls = []
instances = []
for labels in mosaic_labels:
cls.append(labels["cls"])
instances.append(labels["instances"])
final_labels = {
"ori_shape": (self.img_size * 2, self.img_size * 2),
"resized_shape": (self.img_size * 2, self.img_size * 2),
"im_file": mosaic_labels[0]["im_file"],
"cls": np.concatenate(cls, 0)}
final_labels["instances"] = Instances.concatenate(instances, axis=0)
final_labels["instances"].clip(self.img_size * 2, self.img_size * 2)
return final_labels
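# Sketch of the mosaic geometry (illustrative): four images are pasted into a
# 2s x 2s canvas around a random center (xc, yc) sampled from the border
# range. With s=640 and border=(-320, -320), xc and yc are drawn from
# U(320, 960), so the center always stays at least 320 px inside the canvas:
#   m = Mosaic(img_size=640, p=1.0, border=(-320, -320))
#   # int(random.uniform(-(-320), 2 * 640 + (-320))) -> a value in [320, 960)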
class MixUp(BaseMixTransform):
def __init__(self, pre_transform=None, p=0.0) -> None:
super().__init__(pre_transform=pre_transform, p=p)
def get_indexes(self, dataset):
return random.randint(0, len(dataset) - 1)
def _mix_transform(self, labels):
im = labels["img"]
labels2 = labels["mix_labels"][0]
im2 = labels2["img"]
cls2 = labels2["cls"]
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
im = (im * r + im2 * (1 - r)).astype(np.uint8)
cat_instances = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
cls = labels["cls"]
labels["img"] = im
labels["instances"] = cat_instances
labels["cls"] = np.concatenate([cls, cls2], 0)
return labels
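# The MixUp blend in isolation (illustrative sketch with dummy arrays):
#   im1 = np.full((4, 4, 3), 200, dtype=np.uint8)
#   im2 = np.full((4, 4, 3), 100, dtype=np.uint8)
#   r = np.random.beta(32.0, 32.0)  # concentrated near 0.5
#   mixed = (im1 * r + im2 * (1 - r)).astype(np.uint8)  # ~150 per pixel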
class RandomPerspective:
def __init__(self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0)):
self.degrees = degrees
self.translate = translate
self.scale = scale
self.shear = shear
self.perspective = perspective
# mosaic border
self.border = border
def affine_transform(self, img):
# Center
C = np.eye(3)
C[0, 2] = -img.shape[1] / 2 # x translation (pixels)
C[1, 2] = -img.shape[0] / 2 # y translation (pixels)
# Perspective
P = np.eye(3)
P[2, 0] = random.uniform(-self.perspective, self.perspective) # x perspective (about y)
P[2, 1] = random.uniform(-self.perspective, self.perspective) # y perspective (about x)
# Rotation and Scale
R = np.eye(3)
a = random.uniform(-self.degrees, self.degrees)
# a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
s = random.uniform(1 - self.scale, 1 + self.scale)
# s = 2 ** random.uniform(-scale, scale)
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
# Shear
S = np.eye(3)
S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # x shear (deg)
S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # y shear (deg)
# Translation
T = np.eye(3)
T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0] # x translation (pixels)
T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1] # y translation (pixels)
# Combined rotation matrix
M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
# affine image
if (self.border[0] != 0) or (self.border[1] != 0) or (M != np.eye(3)).any(): # image changed
if self.perspective:
img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))
else: # affine
img = cv2.warpAffine(img, M[:2], dsize=self.size, borderValue=(114, 114, 114))
return img, M, s
def apply_bboxes(self, bboxes, M):
"""apply affine to bboxes only.
Args:
bboxes(ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
M(ndarray): affine matrix.
Returns:
new_bboxes(ndarray): bboxes after affine, [num_bboxes, 4].
"""
n = len(bboxes)
if n == 0:
return bboxes
xy = np.ones((n * 4, 3))
xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
xy = xy @ M.T # transform
xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
# create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
def apply_segments(self, segments, M):
"""apply affine to segments and generate new bboxes from segments.
Args:
segments(ndarray): list of segments, [num_samples, 500, 2].
M(ndarray): affine matrix.
Returns:
new_bboxes(ndarray): bboxes after affine, [N, 4].
new_segments(ndarray): list of segments after affine, [num_samples, 500, 2].
"""
n, num = segments.shape[:2]
if n == 0:
return [], segments
xy = np.ones((n * num, 3))
segments = segments.reshape(-1, 2)
xy[:, :2] = segments
xy = xy @ M.T # transform
xy = xy[:, :2] / xy[:, 2:3]
segments = xy.reshape(n, -1, 2)
bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0)
return bboxes, segments
def apply_keypoints(self, keypoints, M):
"""apply affine to keypoints.
Args:
keypoints(ndarray): keypoints, [N, 17, 2].
M(ndarray): affine matrix.
Return:
new_keypoints(ndarray): keypoints after affine, [N, 17, 2].
"""
n = len(keypoints)
if n == 0:
return keypoints
new_keypoints = np.ones((n * 17, 3))
new_keypoints[:, :2] = keypoints.reshape(n * 17, 2) # num_kpt is hardcoded to 17
new_keypoints = new_keypoints @ M.T # transform
new_keypoints = (new_keypoints[:, :2] / new_keypoints[:, 2:3]).reshape(n, 34) # perspective rescale or affine
new_keypoints[keypoints.reshape(-1, 34) == 0] = 0
x_kpts = new_keypoints[:, list(range(0, 34, 2))]
y_kpts = new_keypoints[:, list(range(1, 34, 2))]
kpt_mask = np.logical_or.reduce((x_kpts < 0, x_kpts > self.size[0], y_kpts < 0, y_kpts > self.size[1]))
x_kpts[kpt_mask] = 0
y_kpts[kpt_mask] = 0
new_keypoints[:, list(range(0, 34, 2))] = x_kpts
new_keypoints[:, list(range(1, 34, 2))] = y_kpts
return new_keypoints.reshape(n, 17, 2)
def __call__(self, labels):
"""
Affine images and targets.
Args:
labels(Dict): a dict containing `img`, `cls` and `instances` (bboxes, segments, keypoints).
"""
img = labels["img"]
cls = labels["cls"]
instances = labels["instances"]
# make sure the coord formats are right
instances.convert_bbox(format="xyxy")
instances.denormalize(*img.shape[:2][::-1])
self.size = img.shape[1] + self.border[1] * 2, img.shape[0] + self.border[0] * 2 # w, h
# M is affine matrix
# scale for func:`box_candidates`
img, M, scale = self.affine_transform(img)
bboxes = self.apply_bboxes(instances.bboxes, M)
segments = instances.segments
keypoints = instances.keypoints
# update bboxes if there are segments.
if segments is not None:
bboxes, segments = self.apply_segments(segments, M)
if keypoints is not None:
keypoints = self.apply_keypoints(keypoints, M)
new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
new_instances.clip(*self.size)
# filter instances
instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
# make the bboxes have the same scale as new_bboxes
i = self.box_candidates(box1=instances.bboxes.T,
box2=new_instances.bboxes.T,
area_thr=0.01 if segments is not None else 0.10)
labels["instances"] = new_instances[i]
# clip
labels["cls"] = cls[i]
labels["img"] = img
return labels
def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
# Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
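# Worked example of box_candidates (illustrative): a 100x100 box squeezed to
# 1.5 px wide by the affine fails the wh_thr=2 test and is filtered out:
#   rp = RandomPerspective()
#   box1 = np.array([[0.], [0.], [100.], [100.]])  # (4, n) before augment
#   box2 = np.array([[0.], [0.], [1.5], [100.]])  # (4, n) after augment
#   rp.box_candidates(box1, box2)  # -> array([False])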
class RandomHSV:
def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None:
self.hgain = hgain
self.sgain = sgain
self.vgain = vgain
def __call__(self, labels):
img = labels["img"]
if self.hgain or self.sgain or self.vgain:
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
dtype = img.dtype # uint8
x = np.arange(0, 256, dtype=r.dtype)
lut_hue = ((x * r[0]) % 180).astype(dtype)
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
labels["img"] = img
return labels
class RandomFlip:
def __init__(self, p=0.5, direction="horizontal") -> None:
assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
assert 0 <= p <= 1.0
self.p = p
self.direction = direction
def __call__(self, labels):
img = labels["img"]
instances = labels["instances"]
instances.convert_bbox(format="xywh")
h, w = img.shape[:2]
h = 1 if instances.normalized else h
w = 1 if instances.normalized else w
# Flip up-down
if self.direction == "vertical" and random.random() < self.p:
img = np.flipud(img)
img = np.ascontiguousarray(img)
instances.flipud(h)
if self.direction == "horizontal" and random.random() < self.p:
img = np.fliplr(img)
img = np.ascontiguousarray(img)
instances.fliplr(w)
labels["img"] = img
labels["instances"] = instances
return labels
class LetterBox:
"""Resize image and padding for detection, instance segmentation, pose"""
def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
self.new_shape = new_shape
self.auto = auto
self.scaleFill = scaleFill
self.scaleup = scaleup
self.stride = stride
def __call__(self, labels):
img = labels["img"]
shape = img.shape[:2] # current shape [height, width]
new_shape = labels.get("rect_shape", self.new_shape)
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
if not self.scaleup: # only scale down, do not scale up (for better val mAP)
r = min(r, 1.0)
# Compute padding
ratio = r, r # width, height ratios
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
if self.auto: # minimum rectangle
dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
elif self.scaleFill: # stretch
dw, dh = 0.0, 0.0
new_unpad = (new_shape[1], new_shape[0])
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
dw /= 2 # divide padding into 2 sides
dh /= 2
if shape[::-1] != new_unpad: # resize
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
value=(114, 114, 114)) # add border
labels = self._update_labels(labels, ratio, dw, dh)
labels["img"] = img
return labels
def _update_labels(self, labels, ratio, padw, padh):
"""Update labels"""
labels["instances"].convert_bbox(format="xyxy")
labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
labels["instances"].scale(*ratio)
labels["instances"].add_padding(padw, padh)
return labels
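# Worked example of the letterbox arithmetic (illustrative): a 720x1280 (h, w)
# image resized into a 640x640 canvas gives r = min(640/720, 640/1280) = 0.5,
# new_unpad = (640, 360) (w, h), and dw, dh = 0, 280 before halving, i.e.
# 140 px of gray padding on the top and bottom:
#   letterbox = LetterBox(new_shape=(640, 640))
#   # labels = letterbox({"img": im, "instances": instances, "cls": cls})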
class CopyPaste:
def __init__(self, p=0.5) -> None:
self.p = p
def __call__(self, labels):
# Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
im = labels["img"]
cls = labels["cls"]
bboxes = labels["instances"].bboxes
segments = labels["instances"].segments # n, 1000, 2
keypoints = labels["instances"].keypoints
if self.p and segments is not None:
n = len(segments)
h, w, _ = im.shape # height, width, channels
im_new = np.zeros(im.shape, np.uint8)
# TODO: this implement can be parallel since segments are ndarray, also might work with Instances inside
for j in random.sample(range(n), k=round(self.p * n)):
c, b, s = cls[j], bboxes[j], segments[j]
box = w - b[2], b[1], w - b[0], b[3]
ioa = bbox_ioa(box, bboxes) # intersection over area
if (ioa < 0.30).all(): # allow 30% obscuration of existing labels
bboxes = np.concatenate((bboxes, [box]), 0)
cls = np.concatenate((cls, c[None]), axis=0)
segments = np.concatenate((segments, np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)[None]), 0)
if keypoints is not None:
keypoints = np.concatenate(
(keypoints, np.concatenate((w - keypoints[j][:, 0:1], keypoints[j][:, 1:2]), 1)), 0)
cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)
result = cv2.bitwise_and(src1=im, src2=im_new)
result = cv2.flip(result, 1) # augment segments (flip left-right)
i = result > 0 # pixels to replace
# i[:, :] = result.max(2).reshape(h, w, 1) # act over ch
im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
labels["img"] = im
labels["cls"] = cls
labels["instances"].update(bboxes, segments, keypoints)
return labels
class Albumentations:
# YOLOv5 Albumentations class (optional, only used if package is installed)
def __init__(self, p=1.0):
self.p = p
self.transform = None
prefix = colorstr("albumentations: ")
try:
import albumentations as A
check_version(A.__version__, "1.0.3", hard=True) # version requirement
T = [
A.Blur(p=0.01),
A.MedianBlur(p=0.01),
A.ToGray(p=0.01),
A.CLAHE(p=0.01),
A.RandomBrightnessContrast(p=0.0),
A.RandomGamma(p=0.0),
A.ImageCompression(quality_lower=75, p=0.0),] # transforms
self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
except ImportError: # package not installed, skip
pass
except Exception as e:
LOGGER.info(f"{prefix}{e}")
def __call__(self, labels):
im = labels["img"]
cls = labels["cls"]
if len(cls):
labels["instances"].convert_bbox("xywh")
labels["instances"].normalize(*im.shape[:2][::-1])
bboxes = labels["instances"].bboxes
# TODO: add supports of segments and keypoints
if self.transform and random.random() < self.p:
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
labels["img"] = new["image"]
labels["cls"] = np.array(new["class_labels"])
bboxes = np.array(new["bboxes"], dtype=np.float32).reshape(-1, 4)
labels["instances"].update(bboxes=bboxes)
return labels
# TODO: technically this is not an augmentation, maybe we should put it in another file
class Format:
def __init__(self, bbox_format="xywh", normalize=True, mask=False, mask_ratio=4, mask_overlap=True, batch_idx=True):
self.bbox_format = bbox_format
self.normalize = normalize
self.mask = mask # set False when training detection only
self.mask_ratio = mask_ratio
self.mask_overlap = mask_overlap
self.batch_idx = batch_idx # keep the batch indexes
def __call__(self, labels):
img = labels["img"]
h, w = img.shape[:2]
cls = labels.pop("cls")
instances = labels.pop("instances")
instances.convert_bbox(format=self.bbox_format)
instances.denormalize(w, h)
nl = len(instances)
if instances.segments is not None and self.mask:
masks, instances, cls = self._format_segments(instances, cls, w, h)
labels["masks"] = (torch.from_numpy(masks) if nl else torch.zeros(
1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio))
if self.normalize:
instances.normalize(w, h)
labels["img"] = self._format_img(img)
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
if instances.keypoints is not None:
labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
# then we can use collate_fn
if self.batch_idx:
labels["batch_idx"] = torch.zeros(nl)
return labels
def _format_img(self, img):
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = torch.from_numpy(img)
return img
def _format_segments(self, instances, cls, w, h):
"""convert polygon points to bitmap"""
segments = instances.segments
if self.mask_overlap:
masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio)
masks = masks[None] # (640, 640) -> (1, 640, 640)
instances = instances[sorted_idx]
cls = cls[sorted_idx]
else:
masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio)
return masks, instances, cls
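# After Format, each sample is a dict of tensors ready for a collate_fn
# (sketch of the expected keys and shapes, assuming nl boxes, normalize=True):
#   labels["img"]  # (C, H, W) uint8 tensor
#   labels["cls"]  # (nl, 1) class ids
#   labels["bboxes"]  # (nl, 4) normalized xywh
#   labels["batch_idx"]  # (nl,) zeros, offset later by the collate_fn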
def mosaic_transforms(img_size, hyp):
pre_transform = Compose([
Mosaic(img_size=img_size, p=hyp.mosaic, border=[-img_size // 2, -img_size // 2]),
CopyPaste(p=hyp.copy_paste),
RandomPerspective(
degrees=hyp.degrees,
translate=hyp.translate,
scale=hyp.scale,
shear=hyp.shear,
perspective=hyp.perspective,
border=[-img_size // 2, -img_size // 2],
),])
transforms = Compose([
pre_transform,
MixUp(
pre_transform=pre_transform,
p=hyp.mixup,
),
Albumentations(p=1.0),
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
RandomFlip(direction="vertical", p=hyp.flipud),
RandomFlip(direction="horizontal", p=hyp.fliplr),])
return transforms
def affine_transforms(img_size, hyp):
# letterbox, random perspective, albumentations, hsv, flipud, fliplr
transforms = Compose([
LetterBox(new_shape=(img_size, img_size)),
RandomPerspective(
degrees=hyp.degrees,
translate=hyp.translate,
scale=hyp.scale,
shear=hyp.shear,
perspective=hyp.perspective,
border=[0, 0],
),
Albumentations(p=1.0),
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
RandomFlip(direction="vertical", p=hyp.flipud),
RandomFlip(direction="horizontal", p=hyp.fliplr),])
return transforms
# Classification augmentations -------------------------------------------------------------------------------------------
def classify_transforms(size=224):
# Transforms to apply if albumentations not installed
assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
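# Usage sketch (illustrative; the image path is hypothetical): the fallback
# torchvision pipeline maps a BGR uint8 HWC image to a normalized CHW tensor:
#   im = cv2.imread("image.jpg")  # HWC, BGR, uint8
#   x = classify_transforms(224)(im)  # (3, 224, 224) float, ImageNet-normalized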
def classify_albumentations(
augment=True,
size=224,
scale=(0.08, 1.0),
hflip=0.5,
vflip=0.0,
jitter=0.4,
mean=IMAGENET_MEAN,
std=IMAGENET_STD,
auto_aug=False,
):
# YOLOv5 classification Albumentations (optional, only used if package is installed)
prefix = colorstr("albumentations: ")
try:
import albumentations as A
from albumentations.pytorch import ToTensorV2
check_version(A.__version__, "1.0.3", hard=True) # version requirement
if augment: # Resize and crop
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
if auto_aug:
# TODO: implement AugMix, AutoAug & RandAug in albumentation
LOGGER.info(f"{prefix}auto augmentations are currently not supported")
else:
if hflip > 0:
T += [A.HorizontalFlip(p=hflip)]
if vflip > 0:
T += [A.VerticalFlip(p=vflip)]
if jitter > 0:
color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, saturation, 0 hue
T += [A.ColorJitter(*color_jitter, 0)]
else: # Use fixed crop for eval set (reproducibility)
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
return A.Compose(T)
except ImportError: # package not installed, skip
pass
except Exception as e:
LOGGER.info(f"{prefix}{e}")
class ClassifyLetterBox:
# YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
def __init__(self, size=(640, 640), auto=False, stride=32):
super().__init__()
self.h, self.w = (size, size) if isinstance(size, int) else size
self.auto = auto # pass max size integer, automatically solve for short side using stride
self.stride = stride # used with auto
def __call__(self, im): # im = np.array HWC
imh, imw = im.shape[:2]
r = min(self.h / imh, self.w / imw) # ratio of new/old
h, w = round(imh * r), round(imw * r) # resized image
hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else (self.h, self.w)
top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)
im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
return im_out
class CenterCrop:
# YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
def __init__(self, size=640):
super().__init__()
self.h, self.w = (size, size) if isinstance(size, int) else size
def __call__(self, im): # im = np.array HWC
imh, imw = im.shape[:2]
m = min(imh, imw) # min dimension
top, left = (imh - m) // 2, (imw - m) // 2
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
class ToTensor:
# YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
def __init__(self, half=False):
super().__init__()
self.half = half
def __call__(self, im): # im = np.array HWC in BGR order
im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
im = torch.from_numpy(im) # to torch
im = im.half() if self.half else im.float() # uint8 to fp16/32
im /= 255.0 # 0-255 to 0.0-1.0
return im

View File

@@ -0,0 +1,224 @@
import glob
import os
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Optional
import cv2
import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm
from ..utils.general import NUM_THREADS
from .utils import BAR_FORMAT, HELP_URL, IMG_FORMATS, LOCAL_RANK
class BaseDataset(Dataset):
"""Base Dataset.
Args:
img_path (str): image path.
img_size (int): target image size. Defaults to 640.
label_path (str): label path; this can also be an ann_file or another custom label path.
"""
def __init__(
self,
img_path,
img_size=640,
label_path=None,
cache=False,
augment=True,
hyp=None,
prefix="",
rect=False,
batch_size=None,
stride=32,
pad=0.5,
single_cls=False,
):
super().__init__()
self.img_path = img_path
self.img_size = img_size
self.label_path = label_path
self.augment = augment
self.prefix = prefix
self.im_files = self.get_img_files(self.img_path)
self.labels = self.get_labels()
if single_cls:
self.update_labels(include_class=[], single_cls=single_cls)
self.ni = len(self.im_files)
# rect stuff
self.rect = rect
self.batch_size = batch_size
self.stride = stride
self.pad = pad
if self.rect:
assert self.batch_size is not None
self.set_rectangle()
# cache stuff
self.ims = [None] * self.ni
self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
self.cache = cache
if cache:
self.cache_images()
# transforms
self.transforms = self.build_transforms(hyp=hyp)
def get_img_files(self, img_path):
"""Read image files."""
try:
f = [] # image files
for p in img_path if isinstance(img_path, list) else [img_path]:
p = Path(p) # os-agnostic
if p.is_dir(): # dir
f += glob.glob(str(p / "**" / "*.*"), recursive=True)
# f = list(p.rglob('*.*')) # pathlib
elif p.is_file(): # file
with open(p) as t:
t = t.read().strip().splitlines()
parent = str(p.parent) + os.sep
f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path
# f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
else:
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
assert im_files, f"{self.prefix}No images found"
except Exception as e:
raise Exception(f"{self.prefix}Error loading data from {img_path}: {e}\n{HELP_URL}")
return im_files
def update_labels(self, include_class: Optional[list], single_cls=False):
"""Filter labels to include only these classes (optional) and optionally collapse all classes to one."""
include_class_array = np.array(include_class).reshape(1, -1)
for i in range(len(self.labels)):
if include_class:
cls = self.labels[i]["cls"]
bboxes = self.labels[i]["bboxes"]
segments = self.labels[i]["segments"]
j = (cls == include_class_array).any(1)
self.labels[i]["cls"] = cls[j]
self.labels[i]["bboxes"] = bboxes[j]
if segments:
self.labels[i]["segments"] = segments[j]
if single_cls:
self.labels[i]["cls"][:, 0] = 0
def load_image(self, i):
# Loads 1 image from dataset index 'i', returns (im, resized hw)
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
if im is None: # not cached in RAM
if fn.exists(): # load npy
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
assert im is not None, f"Image Not Found {f}"
h0, w0 = im.shape[:2] # orig hw
r = self.img_size / max(h0, w0) # ratio
if r != 1: # if sizes are not equal
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
def cache_images(self):
# cache images to memory or disk
gb = 0 # Gigabytes of cached images
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
fcn = self.cache_images_to_disk if self.cache == "disk" else self.load_image
results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni))
pbar = tqdm(enumerate(results), total=self.ni, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar:
if self.cache == "disk":
gb += self.npy_files[i].stat().st_size
else: # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
gb += self.ims[i].nbytes
pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {self.cache})"
pbar.close()
def cache_images_to_disk(self, i):
# Saves an image as an *.npy file for faster loading
f = self.npy_files[i]
if not f.exists():
np.save(f.as_posix(), cv2.imread(self.im_files[i]))
def set_rectangle(self):
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
nb = bi[-1] + 1 # number of batches
s = np.array([x["shape"] for x in self.labels]) # hw
ar = s[:, 0] / s[:, 1] # aspect ratio
irect = ar.argsort()
self.im_files = [self.im_files[i] for i in irect]
self.labels = [self.labels[i] for i in irect]
ar = ar[irect]
# Set training image shapes
shapes = [[1, 1]] * nb
for i in range(nb):
ari = ar[bi == i]
mini, maxi = ari.min(), ari.max()
if maxi < 1:
shapes[i] = [maxi, 1]
elif mini > 1:
shapes[i] = [1, 1 / mini]
self.batch_shapes = np.ceil(np.array(shapes) * self.img_size / self.stride + self.pad).astype(int) * self.stride
self.batch = bi # batch index of image
def __getitem__(self, index):
label = self.get_label_info(index)
if self.augment:
label["dataset"] = self
return self.transforms(label)
def get_label_info(self, index):
label = self.labels[index].copy()
img, (h0, w0), (h, w) = self.load_image(index)
label["img"] = img
label["ori_shape"] = (h0, w0)
label["resized_shape"] = (h, w)
if self.rect:
label["rect_shape"] = self.batch_shapes[self.batch[index]]
label = self.update_labels_info(label)
return label
def __len__(self):
return len(self.im_files)
def update_labels_info(self, label):
"""custom your label format here"""
return label
def build_transforms(self, hyp=None):
"""Users can custom augmentations here
like:
if self.augment:
# training transforms
return Compose([])
else:
# val transforms
return Compose([])
"""
raise NotImplementedError
def get_labels(self):
"""Users can custom their own format here.
Make sure your output is a list with each element like below:
dict(
im_file=im_file,
shape=shape, # format: (height, width)
cls=cls,
bboxes=bboxes, # xywh
segments=segments, # xy
keypoints=keypoints, # xy
normalized=True, # or False
bbox_format="xyxy", # or xywh, ltwh
)
"""
raise NotImplementedError

View File

@@ -0,0 +1,145 @@
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, dataloader, distributed
from ..utils.general import LOGGER
from ..utils.torch_utils import torch_distributed_zero_first
from .dataset import ClassificationDataset, YOLODataset
from .utils import PIN_MEMORY, RANK
class InfiniteDataLoader(dataloader.DataLoader):
"""Dataloader that reuses workers
Uses same syntax as vanilla DataLoader
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__()
def __len__(self):
return len(self.batch_sampler.sampler)
def __iter__(self):
for _ in range(len(self)):
yield next(self.iterator)
class _RepeatSampler:
"""Sampler that repeats forever
Args:
sampler (Sampler)
"""
def __init__(self, sampler):
self.sampler = sampler
def __iter__(self):
while True:
yield from iter(self.sampler)
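# Why the wrapper (illustrative sketch): a vanilla DataLoader tears down and
# respawns its worker processes at every epoch boundary, while _RepeatSampler
# never exhausts, so InfiniteDataLoader serves every epoch from one
# persistent worker pool:
#   loader = InfiniteDataLoader(dataset, batch_size=16, num_workers=4)
#   for epoch in range(300):
#       for batch in loader:  # len(loader) batches, same workers each epoch
#           ...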
def seed_worker(worker_id):
# Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
worker_seed = torch.initial_seed() % 2 ** 32
np.random.seed(worker_seed)
random.seed(worker_seed)
# TODO: we can inject most args from a config file
def build_dataloader(
img_path,
img_size,
batch_size,
single_cls=False,
hyp=None,
augment=False,
cache=False,
image_weights=False,
stride=32,
label_path=None,
pad=0.0,
rect=False,
rank=-1,
workers=8,
prefix="",
shuffle=False,
use_segments=False,
use_keypoints=False,
):
if rect and shuffle:
LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = YOLODataset(
img_path=img_path,
img_size=img_size,
batch_size=batch_size,
label_path=label_path,
augment=augment, # augmentation
hyp=hyp,
rect=rect, # rectangular batches
cache=cache,
single_cls=single_cls,
stride=int(stride),
pad=pad,
prefix=prefix,
use_segments=use_segments,
use_keypoints=use_keypoints,
)
batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
return (
loader(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle and sampler is None,
num_workers=nw,
sampler=sampler,
pin_memory=PIN_MEMORY,
collate_fn=getattr(dataset, "collate_fn", None),
worker_init_fn=seed_worker,
generator=generator,
),
dataset,
)
# build classification
def build_classification_dataloader(path,
imgsz=224,
batch_size=16,
augment=True,
cache=False,
rank=-1,
workers=8,
shuffle=True):
# Returns Dataloader object to be used with YOLOv5 Classifier
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count()
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
return InfiniteDataLoader(dataset,
batch_size=batch_size,
shuffle=shuffle and sampler is None,
num_workers=nw,
sampler=sampler,
pin_memory=PIN_MEMORY,
worker_init_fn=seed_worker,
generator=generator) # or DataLoader(persistent_workers=True)
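# Usage sketch (illustrative; the dataset path and hyp namespace are
# hypothetical):
#   loader, dataset = build_dataloader("coco128/images/train2017", img_size=640,
#       batch_size=16, hyp=hyp, augment=True, workers=8, shuffle=True)
#   for batch in loader:
#       imgs, bboxes, idx = batch["img"], batch["bboxes"], batch["batch_idx"]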

View File

@@ -0,0 +1,213 @@
from itertools import repeat
from multiprocessing.pool import Pool
from pathlib import Path
import cv2
import numpy as np
import torch
import torchvision
from tqdm import tqdm
from ..utils.general import LOGGER, NUM_THREADS
from .augment import *
from .base import BaseDataset
from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
class YOLODataset(BaseDataset):
"""YOLO Dataset.
Args:
img_path (str): image path.
prefix (str): prefix.
"""
cache_version = 0.6 # dataset labels *.cache version
rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
def __init__(
self,
img_path,
img_size=640,
label_path=None,
cache=False,
augment=True,
hyp=None,
prefix="",
rect=False,
batch_size=None,
stride=32,
pad=0.0,
single_cls=False,
use_segments=False,
use_keypoints=False,
):
self.use_segments = use_segments
self.use_keypoints = use_keypoints
assert not (self.use_segments and self.use_keypoints), "Can not use segments and keypoints at the same time."
super().__init__(img_path, img_size, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
single_cls)
def cache_labels(self, path=Path("./labels.cache")):
# Cache dataset labels, check images and read shapes
x = {"labels": []}
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{self.prefix}Scanning '{path.parent / path.stem}' images and labels..."
with Pool(NUM_THREADS) as pool:
pbar = tqdm(
pool.imap(verify_image_label,
zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
desc=desc,
total=len(self.im_files),
bar_format=BAR_FORMAT,
)
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f
nf += nf_f
ne += ne_f
nc += nc_f
if im_file:
x["labels"].append(
dict(
im_file=im_file,
shape=shape,
cls=lb[:, 0:1], # n, 1
bboxes=lb[:, 1:], # n, 4
segments=segments,
keypoints=keypoint,
normalized=True,
bbox_format="xywh",
))
if msg:
msgs.append(msg)
pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
pbar.close()
if msgs:
LOGGER.info("\n".join(msgs))
if nf == 0:
LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
x["hash"] = get_hash(self.label_files + self.im_files)
x["results"] = nf, nm, ne, nc, len(self.im_files)
x["msgs"] = msgs # warnings
x["version"] = self.cache_version # cache version
try:
np.save(path, x) # save cache for next time
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
LOGGER.info(f"{self.prefix}New cache created: {path}")
except Exception as e:
LOGGER.warning(
f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}") # not writeable
return x
def get_labels(self):
self.label_files = img2label_paths(self.im_files)
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
try:
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
assert cache["version"] == self.cache_version # matches current version
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
except Exception:
cache, exists = self.cache_labels(cache_path), False # run cache ops
# Display cache
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
if exists and LOCAL_RANK in {-1, 0}:
d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
if cache["msgs"]:
LOGGER.info("\n".join(cache["msgs"])) # display warnings
assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"
# Read cache
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
labels = cache["labels"]
nl = len(np.concatenate([label["cls"] for label in labels], 0)) # number of labels
assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}"
return labels
# TODO: use hyp config to set all these augmentations
def build_transforms(self, hyp=None):
mosaic = self.augment and not self.rect
# mosaic = False
if self.augment:
if mosaic:
transforms = mosaic_transforms(self.img_size, hyp)
else:
transforms = affine_transforms(self.img_size, hyp)
else:
transforms = Compose([LetterBox(new_shape=(self.img_size, self.img_size))])
transforms.append(Format(bbox_format="xywh", normalize=True, mask=self.use_segments, batch_idx=True))
return transforms
def update_labels_info(self, label):
"""custom your label format here"""
# NOTE: cls is not kept with bboxes now, since other tasks like classification and semantic segmentation need an independent cls label
# we can make this also support classification and semantic segmentation by adding or removing dict keys here.
bboxes = label.pop("bboxes")
segments = label.pop("segments", None)
keypoints = label.pop("keypoints", None)
bbox_format = label.pop("bbox_format")
normalized = label.pop("normalized")
label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
return label
@staticmethod
def collate_fn(batch):
# TODO: returning a dict can make things easier and cleaner when using the dataset in training,
# but it is unclear whether this slows things down slightly.
new_batch = {}
keys = batch[0].keys()
values = list(zip(*[list(b.values()) for b in batch]))
for i, k in enumerate(keys):
value = values[i]
if k == "img":
value = torch.stack(value, 0)
if k in ["mask", "keypoint", "bboxes", "cls"]:
value = torch.cat(value, 0)
new_batch[k] = values[i]
new_batch["batch_idx"] = list(new_batch["batch_idx"])
for i in range(len(new_batch["batch_idx"])):
new_batch["batch_idx"][i] += i # add target image index for build_targets()
new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
return new_batch
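# Worked example of the batch_idx offset (illustrative; sample0 and sample1
# are hypothetical dicts from __getitem__): for two samples with 3 and 2
# boxes, each arrives with batch_idx == zeros(nl); after collation batch_idx
# is [0, 0, 0, 1, 1], letting build_targets() map every box back to its
# image in the stacked batch:
#   collated = YOLODataset.collate_fn([sample0, sample1])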
# Classification dataloaders -------------------------------------------------------------------------------------------
class ClassificationDataset(torchvision.datasets.ImageFolder):
"""
YOLOv5 Classification Dataset.
Arguments
root: Dataset path
transform: torchvision transforms, used by default
album_transform: Albumentations transforms, used if installed
"""
def __init__(self, root, augment, imgsz, cache=False):
super().__init__(root=root)
self.torch_transforms = classify_transforms(imgsz)
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
self.cache_ram = cache is True or cache == "ram"
self.cache_disk = cache == "disk"
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
def __getitem__(self, i):
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
if self.cache_ram and im is None:
im = self.samples[i][3] = cv2.imread(f)
elif self.cache_disk:
if not fn.exists(): # load npy
np.save(fn.as_posix(), cv2.imread(f))
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
if self.album_transforms:
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
else:
sample = self.torch_transforms(im)
return sample, j
# TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
def __init__(self):
pass

View File

@@ -0,0 +1,37 @@
import collections
from copy import deepcopy
from .augment import LetterBox
class MixAndRectDataset:
"""A wrapper of multiple images mixed dataset.
Args:
dataset (:obj:`BaseDataset`): The dataset to be mixed.
transforms (Sequence[dict]): config dict to be composed.
"""
def __init__(self, dataset):
self.dataset = dataset
self.img_size = dataset.img_size
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
labels = deepcopy(self.dataset[index])
for transform in self.dataset.transforms.tolist():
# mosaic and mixup
if hasattr(transform, "get_indexes"):
indexes = transform.get_indexes(self.dataset)
if not isinstance(indexes, collections.abc.Sequence):
indexes = [indexes]
mix_labels = [deepcopy(self.dataset[index]) for index in indexes]
labels["mix_labels"] = mix_labels
if self.dataset.rect and isinstance(transform, LetterBox):
transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
labels = transform(labels)
if "mix_labels" in labels:
labels.pop("mix_labels")
return labels
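# Usage sketch (illustrative; yolo_dataset is a hypothetical rect-enabled
# YOLODataset): the wrapper feeds Mosaic/MixUp their extra images and points
# LetterBox at the per-batch rectangular shape:
#   dataset = MixAndRectDataset(yolo_dataset)
#   labels = dataset[0]  # fully transformed sample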

View File

@@ -0,0 +1,177 @@
import contextlib
import hashlib
import os
import cv2
import numpy as np
from PIL import ExifTags, Image, ImageOps
from ..utils.general import segments2boxes
HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv" # include video suffixes
BAR_FORMAT = "{l_bar}{bar:10}{r_bar}{bar:-10b}" # tqdm bar format
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
if ExifTags.TAGS[orientation] == "Orientation":
break
def img2label_paths(img_paths):
# Define label paths as a function of image paths
sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings
return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
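# Example (illustrative, POSIX paths): only the last /images/ segment and the
# file extension are swapped:
#   img2label_paths(["datasets/coco/images/train/0001.jpg"])
#   # -> ["datasets/coco/labels/train/0001.txt"]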
def get_hash(paths):
# Returns a single hash value of a list of paths (files or dirs)
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
h = hashlib.md5(str(size).encode()) # hash sizes
h.update("".join(paths).encode()) # hash paths
return h.hexdigest() # return hash
def exif_size(img):
# Returns exif-corrected PIL size
s = img.size # (width, height)
with contextlib.suppress(Exception):
rotation = dict(img._getexif().items())[orientation]
if rotation in [6, 8]: # rotation 270 or 90
s = (s[1], s[0])
return s
def verify_image_label(args):
# Verify one image-label pair
im_file, lb_file, prefix, keypoint = args
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None # number (missing, found, empty, corrupt), message, segments, keypoints
try:
# verify images
im = Image.open(im_file)
im.verify() # PIL verify
shape = exif_size(im) # image size
shape = (shape[1], shape[0]) # hw
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
if im.format.lower() in ("jpg", "jpeg"):
with open(im_file, "rb") as f:
f.seek(-2, 2)
if f.read() != b"\xff\xd9": # corrupt JPEG
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
# verify labels
if os.path.isfile(lb_file):
nf = 1 # label found
with open(lb_file) as f:
lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
if any(len(x) > 6 for x in lb) and (not keypoint): # is segment
classes = np.array([x[0] for x in lb], dtype=np.float32)
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
lb = np.array(lb, dtype=np.float32)
nl = len(lb)
if nl:
if keypoint:
assert lb.shape[1] == 56, "labels require 56 columns each"
assert (lb[:, 5::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
assert (lb[:, 6::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
kpts = np.zeros((lb.shape[0], 39))
for i in range(len(lb)):
kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5,
3)) # remove the occlusion parameter from the GT
kpts[i] = np.hstack((lb[i, :5], kpt))
lb = kpts
assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
else:
assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
assert (lb[:, 1:] <=
1).all(), f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
_, i = np.unique(lb, axis=0, return_index=True)
if len(i) < nl: # duplicate row check
lb = lb[i] # remove duplicates
if segments:
segments = [segments[x] for x in i]
msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
else:
ne = 1 # label empty
lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
else:
nm = 1 # label missing
lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
if keypoint:
keypoints = lb[:, 5:].reshape(-1, 17, 2)
lb = lb[:, :5]
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
except Exception as e:
nc = 1
msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
return [None, None, None, None, None, nm, nf, ne, nc, msg]
def polygon2mask(img_size, polygons, color=1, downsample_ratio=1):
"""
Args:
img_size (tuple): The image size.
polygons (np.ndarray): [N, M], N is the number of polygons,
M is the number of coordinate values per polygon (2 per point, so M is divisible by 2).
"""
mask = np.zeros(img_size, dtype=np.uint8)
polygons = np.asarray(polygons)
polygons = polygons.astype(np.int32)
shape = polygons.shape
polygons = polygons.reshape(shape[0], -1, 2)
cv2.fillPoly(mask, polygons, color=color)
nh, nw = (img_size[0] // downsample_ratio, img_size[1] // downsample_ratio)
# NOTE: fillPoly first, then resize, keeps loss calculation consistent
# with the mask-ratio=1 case.
mask = cv2.resize(mask, (nw, nh))
return mask
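# Example (illustrative): rasterize one triangle into an 8x8 mask, then
# downsample by 2 to 4x4:
#   poly = [np.array([0, 0, 7, 0, 0, 7], dtype=np.float32)]  # flat xy values
#   m = polygon2mask((8, 8), poly, color=1, downsample_ratio=2)  # (4, 4) uint8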
def polygons2masks(img_size, polygons, color, downsample_ratio=1):
"""
Args:
img_size (tuple): The image size.
polygons (list[np.ndarray]): a list of polygons;
each polygon is an array of flattened xy values
(2 per point).
"""
masks = []
for si in range(len(polygons)):
mask = polygon2mask(img_size, [polygons[si].reshape(-1)], color, downsample_ratio)
masks.append(mask)
return np.array(masks)
def polygons2masks_overlap(img_size, segments, downsample_ratio=1):
"""Return a (640, 640) overlap mask."""
masks = np.zeros((img_size[0] // downsample_ratio, img_size[1] // downsample_ratio),
dtype=np.int32 if len(segments) > 255 else np.uint8)
areas = []
ms = []
for si in range(len(segments)):
mask = polygon2mask(
img_size,
[segments[si].reshape(-1)],
downsample_ratio=downsample_ratio,
color=1,
)
ms.append(mask)
areas.append(mask.sum())
areas = np.asarray(areas)
index = np.argsort(-areas)
ms = np.array(ms)[index]
for i in range(len(segments)):
mask = ms[i] * (i + 1)
masks = masks + mask
masks = np.clip(masks, a_min=0, a_max=i + 1)
return masks, index
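# Sketch of the overlap encoding (illustrative): the returned mask is one
# integer map where pixel value k marks instance k (1-indexed, 0 background).
# Instances are drawn in descending-area order so smaller objects overwrite
# larger ones where they overlap, and `index` records that ordering so labels
# can be re-sorted to match.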