You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

774 lines
30 KiB

import math
import random
from copy import deepcopy
import cv2
import numpy as np
import torch
import torchvision.transforms as T
from ..utils import LOGGER, colorstr
from ..utils.checks import check_version
from ..utils.instance import Instances
from ..utils.metrics import bbox_ioa
from ..utils.ops import segment2box
from .utils import IMAGENET_MEAN, IMAGENET_STD, polygons2masks, polygons2masks_overlap
# TODO: we might need a BaseTransform so that all of these augmentations are compatible with both classification and semantic segmentation tasks
class BaseTransform:
    """Skeleton for transforms that process a ``labels`` object in three stages.

    Calling an instance runs ``apply_image``, ``apply_instances`` and
    ``apply_semantic`` in that order. All three hooks are no-ops here and are
    meant to be overridden by subclasses.
    """

    def __init__(self) -> None:
        pass

    def apply_image(self, labels):
        """Hook for the image stage; does nothing in the base class."""
        pass

    def apply_instances(self, labels):
        """Hook for the instances stage; does nothing in the base class."""
        pass

    def apply_semantic(self, labels):
        """Hook for the semantic stage; does nothing in the base class."""
        pass

    def __call__(self, labels):
        """Run the three hooks on ``labels`` and return ``labels``.

        The value is returned so that the transform can be chained by a
        pipeline that rebinds its data to each transform's result.
        """
        self.apply_image(labels)
        self.apply_instances(labels)
        self.apply_semantic(labels)
        return labels
class Compose:
    """Apply a list of callables one after another."""

    def __init__(self, transforms):
        """Store ``transforms`` (the list itself, not a copy)."""
        self.transforms = transforms

    def __call__(self, data):
        """Feed ``data`` through every transform, passing each result on to the next."""
        result = data
        for step in self.transforms:
            result = step(result)
        return result

    def append(self, transform):
        """Add ``transform`` to the end of the pipeline."""
        self.transforms.append(transform)

    def tolist(self):
        """Return the underlying list of transforms."""
        return self.transforms

    def __repr__(self):
        """Return the class name followed by one line per transform."""
        body = "".join(f"\n {step}" for step in self.transforms)
        return f"{self.__class__.__name__}({body}\n)"
class BaseMixTransform:
    """Base class for augmentations that mix a sample with other dataset samples.

    Subclasses provide ``get_indexes`` (which other samples to fetch) and
    ``_mix_transform`` (how to combine them). This implementation is from mmyolo.
    """

    def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
        """Keep the dataset, an optional per-sample pre-transform and the apply probability."""
        self.dataset = dataset
        self.pre_transform = pre_transform
        self.p = p

    def __call__(self, labels):
        """Mix ``labels`` with other samples with probability ``p``; otherwise return it unchanged."""
        if random.uniform(0, 1) > self.p:
            return labels

        # A single int index is normalised to a one-element list.
        picked = self.get_indexes()
        if isinstance(picked, int):
            picked = [picked]

        # Fetch the extra samples, then optionally pre-transform each of them.
        extras = [self.dataset.get_label_info(idx) for idx in picked]
        if self.pre_transform is not None:
            extras = [self.pre_transform(item) for item in extras]

        # The extra samples travel inside the dict only for the duration of the mix.
        labels["mix_labels"] = extras
        mixed = self._mix_transform(labels)
        mixed.pop("mix_labels", None)
        return mixed

    def _mix_transform(self, labels):
        """Combine ``labels`` with ``labels["mix_labels"]``; must be overridden."""
        raise NotImplementedError

    def get_indexes(self):
        """Return the dataset index (or indexes) of the samples to mix in; must be overridden."""
        raise NotImplementedError
class Mosaic(BaseMixTransform):
    """Mosaic augmentation: paste the current image and three others around a random centre.

    Args:
        dataset: Object supporting ``len()`` and ``get_label_info(index)``.
        imgsz (int): Side length ``s`` of one tile. The canvas built here is
            ``(2 * s, 2 * s)``. Defaults to 640.
        p (float): Probability of applying the augmentation, in ``[0, 1]``.
        border (Sequence[int]): Two offsets ``x`` that bound the random mosaic
            centre to ``[-x, 2 * s + x]`` on each axis.
    """

    def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
        assert 0 <= p <= 1.0, "The probability should be in range [0, 1]. " f"got {p}."
        # The base class already stores ``dataset`` and ``p``.
        super().__init__(dataset=dataset, p=p)
        self.imgsz = imgsz
        self.border = border

    def get_indexes(self):
        """Return three random dataset indexes (repeats are possible)."""
        return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]

    @staticmethod
    def _tile_coords(i, xc, yc, w, h, s):
        """Return the canvas region and the matching source region for tile ``i``.

        Args:
            i (int): Tile position: 0 top-left, 1 top-right, 2 bottom-left, 3 bottom-right.
            xc, yc (int): Mosaic centre on the canvas.
            w, h (int): Width and height of the source image.
            s (int): Tile size; the canvas is ``2 * s`` on each side.

        Returns:
            tuple: ``((x1a, y1a, x2a, y2a), (x1b, y1b, x2b, y2b))`` where the
            ``a`` box is on the large canvas and the ``b`` box is in the source image.
        """
        if i == 0:  # top left
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        else:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
        return (x1a, y1a, x2a, y2a), (x1b, y1b, x2b, y2b)

    def _mix_transform(self, labels):
        """Build the 4-tile mosaic image and merge the labels of all tiles."""
        mosaic_labels = []
        assert labels.get("rect_shape", None) is None, "rect and mosaic is exclusive."
        assert len(labels.get("mix_labels", [])) > 0, "There are no other images for mosaic augment."
        s = self.imgsz
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border)  # mosaic center x, y
        for i in range(4):
            # Tile 0 is the current sample; tiles 1-3 come from the mixed-in samples.
            labels_patch = (labels if i == 0 else labels["mix_labels"][i - 1]).copy()
            img = labels_patch["img"]
            h, w = labels_patch["resized_shape"]

            if i == 0:
                # Grey (114) canvas, allocated once the channel count is known.
                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)
            (x1a, y1a, x2a, y2a), (x1b, y1b, x2b, y2b) = self._tile_coords(i, xc, yc, w, h, s)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            # Offset of this tile's origin on the canvas, applied to its labels.
            padw = x1a - x1b
            padh = y1a - y1b

            labels_patch = self._update_labels(labels_patch, padw, padh)
            mosaic_labels.append(labels_patch)
        final_labels = self._cat_labels(mosaic_labels)
        final_labels["img"] = img4
        return final_labels

    def _update_labels(self, labels, padw, padh):
        """Convert a tile's instances to pixel xyxy and shift them by ``(padw, padh)``."""
        nh, nw = labels["img"].shape[:2]
        labels["instances"].convert_bbox(format="xyxy")
        labels["instances"].denormalize(nw, nh)
        labels["instances"].add_padding(padw, padh)
        return labels

    def _cat_labels(self, mosaic_labels):
        """Merge per-tile labels into one dict and clip instances to the canvas size.

        ``ori_shape`` and ``im_file`` are taken from the first tile. Returns an
        empty dict when ``mosaic_labels`` is empty.
        """
        if len(mosaic_labels) == 0:
            return {}
        cls = []
        instances = []
        for labels in mosaic_labels:
            cls.append(labels["cls"])
            instances.append(labels["instances"])
        final_labels = {
            "ori_shape": mosaic_labels[0]["ori_shape"],
            "resized_shape": (self.imgsz * 2, self.imgsz * 2),
            "im_file": mosaic_labels[0]["im_file"],
            "cls": np.concatenate(cls, 0),
            "instances": Instances.concatenate(instances, axis=0)}
        final_labels["instances"].clip(self.imgsz * 2, self.imgsz * 2)
        return final_labels
class MixUp(BaseMixTransform):
    """Blend the current sample with one other dataset sample."""

    def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
        super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)

    def get_indexes(self):
        """Return one random dataset index."""
        return random.randint(0, len(self.dataset) - 1)

    def _mix_transform(self, labels):
        """Apply MixUp (https://arxiv.org/pdf/1710.09412.pdf) to image, instances and classes."""
        ratio = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
        other = labels["mix_labels"][0]

        # Weighted sum of the two images, cast back to uint8.
        blended = labels["img"] * ratio + other["img"] * (1 - ratio)
        labels["img"] = blended.astype(np.uint8)

        # Annotations of both samples are simply stacked.
        labels["instances"] = Instances.concatenate([labels["instances"], other["instances"]], axis=0)
        labels["cls"] = np.concatenate([labels["cls"], other["cls"]], 0)
        return labels
class RandomPerspective:
    """Random rotation/scale/shear/translation/perspective warp of an image and its targets.

    Args:
        degrees (float): Rotation is sampled from ``[-degrees, degrees]``.
        translate (float): Translation is sampled from ``[0.5 - translate, 0.5 + translate]``
            and multiplied by the output size.
        scale (float): Scale is sampled from ``[1 - scale, 1 + scale]``.
        shear (float): Shear angle (degrees) is sampled from ``[-shear, shear]`` per axis.
        perspective (float): Perspective terms are sampled from ``[-perspective, perspective]``.
        border (Sequence[int]): Mosaic border; twice each value is added to the output size.
    """

    def __init__(self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0)):
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.perspective = perspective
        # mosaic border
        self.border = border

    def affine_transform(self, img):
        """Sample a random transform, warp ``img`` with it and return ``(img, M, scale)``.

        Reads ``self.size`` (w, h), which ``__call__`` sets before calling this method.
        """
        # Center
        C = np.eye(3)
        C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
        C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

        # Perspective
        P = np.eye(3)
        P[2, 0] = random.uniform(-self.perspective, self.perspective)  # x perspective (about y)
        P[2, 1] = random.uniform(-self.perspective, self.perspective)  # y perspective (about x)

        # Rotation and Scale
        R = np.eye(3)
        a = random.uniform(-self.degrees, self.degrees)
        s = random.uniform(1 - self.scale, 1 + self.scale)
        R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

        # Shear
        S = np.eye(3)
        S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180)  # x shear (deg)
        S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180)  # y shear (deg)

        # Translation (named ``Tr`` so it does not shadow the module-level alias ``T``)
        Tr = np.eye(3)
        Tr[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0]  # x translation (pixels)
        Tr[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1]  # y translation (pixels)

        # Combined matrix
        M = Tr @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT

        # Warp only if there is a border to crop or the matrix is not the identity.
        if (self.border[0] != 0) or (self.border[1] != 0) or (M != np.eye(3)).any():  # image changed
            if self.perspective:
                img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))
            else:  # affine
                img = cv2.warpAffine(img, M[:2], dsize=self.size, borderValue=(114, 114, 114))
        return img, M, s

    def apply_bboxes(self, bboxes, M):
        """Transform boxes by ``M`` and return the axis-aligned boxes enclosing the result.

        Args:
            bboxes (ndarray): Boxes in xyxy format, shape (num_bboxes, 4).
            M (ndarray): 3x3 transform matrix.

        Returns:
            ndarray: Transformed boxes, shape (num_bboxes, 4). The input is
            returned unchanged when it is empty.
        """
        n = len(bboxes)
        if n == 0:
            return bboxes

        # All four corners of every box, in homogeneous coordinates.
        xy = np.ones((n * 4, 3))
        xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ M.T  # transform
        xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

    def apply_segments(self, segments, M):
        """Transform segments by ``M`` and derive new boxes from them.

        Args:
            segments (ndarray): Segments of shape (n, num_points, 2).
            M (ndarray): 3x3 transform matrix.

        Returns:
            tuple: ``(bboxes, segments)``; boxes are computed with ``segment2box``
            using ``self.size``. For empty input ``([], segments)`` is returned.
        """
        n, num = segments.shape[:2]
        if n == 0:
            return [], segments

        xy = np.ones((n * num, 3))
        segments = segments.reshape(-1, 2)
        xy[:, :2] = segments
        xy = xy @ M.T  # transform
        xy = xy[:, :2] / xy[:, 2:3]
        segments = xy.reshape(n, -1, 2)
        bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0)
        return bboxes, segments

    def apply_keypoints(self, keypoints, M):
        """Transform keypoints by ``M`` and zero those that are missing or out of bounds.

        Args:
            keypoints (ndarray): Keypoints of shape (N, num_kpt, 2).
            M (ndarray): 3x3 transform matrix.

        Returns:
            ndarray: Transformed keypoints, shape (N, num_kpt, 2). Coordinates
            that were 0 in the input stay 0; a keypoint that lands outside
            ``[0, size[0]] x [0, size[1]]`` has both coordinates set to 0.
        """
        n = len(keypoints)
        if n == 0:
            return keypoints
        nkpt = keypoints.shape[1]  # number of keypoints per instance (no longer fixed at 17)

        xy = np.ones((n * nkpt, 3))
        xy[:, :2] = keypoints.reshape(n * nkpt, 2)
        xy = xy @ M.T  # transform
        xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, nkpt * 2)  # perspective rescale or affine
        xy[keypoints.reshape(n, nkpt * 2) == 0] = 0

        # Strided slices are views, so the writes below land directly in ``xy``.
        x_kpts = xy[:, 0::2]
        y_kpts = xy[:, 1::2]
        # Compute the mask once, BEFORE zeroing anything: recomputing it after
        # x was zeroed would miss keypoints whose x (but not y) was out of range.
        out_of_bounds = (x_kpts < 0) | (x_kpts > self.size[0]) | (y_kpts < 0) | (y_kpts > self.size[1])
        x_kpts[out_of_bounds] = 0
        y_kpts[out_of_bounds] = 0
        return xy.reshape(n, nkpt, 2)

    def __call__(self, labels):
        """Warp the image and its targets, then drop instances that became degenerate.

        Args:
            labels (dict): Must hold ``img``, ``cls`` and ``instances``.

        Returns:
            dict: ``labels`` with ``img``, ``cls``, ``instances`` and ``resized_shape`` updated.
        """
        img = labels["img"]
        cls = labels["cls"]
        instances = labels.pop("instances")
        # make sure the coord formats are right
        instances.convert_bbox(format="xyxy")
        instances.denormalize(*img.shape[:2][::-1])

        self.size = img.shape[1] + self.border[1] * 2, img.shape[0] + self.border[0] * 2  # w, h
        # M is the transform matrix; scale is needed by `box_candidates`
        img, M, scale = self.affine_transform(img)

        bboxes = self.apply_bboxes(instances.bboxes, M)

        segments = instances.segments
        keypoints = instances.keypoints
        # update bboxes if there are segments.
        if len(segments):
            bboxes, segments = self.apply_segments(segments, M)

        if keypoints is not None:
            keypoints = self.apply_keypoints(keypoints, M)
        new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
        # clip
        new_instances.clip(*self.size)

        # Scale the original boxes so they are comparable with the warped ones.
        instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
        i = self.box_candidates(box1=instances.bboxes.T,
                                box2=new_instances.bboxes.T,
                                area_thr=0.01 if len(segments) else 0.10)
        labels["instances"] = new_instances[i]
        labels["cls"] = cls[i]
        labels["img"] = img
        labels["resized_shape"] = img.shape[:2]
        return labels

    def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
        """Return a boolean mask of boxes worth keeping after augmentation.

        Args:
            box1 (ndarray): Boxes before augmentation, shape (4, n).
            box2 (ndarray): Boxes after augmentation, shape (4, n).
            wh_thr (float): Minimum width and height (pixels) of the new box.
            ar_thr (float): Maximum aspect ratio of the new box.
            area_thr (float): Minimum ratio of new area to old area.
            eps (float): Small constant guarding the divisions.
        """
        w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
        w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
        ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
        return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates
class RandomHSV:
    """Randomly scale the hue, saturation and value channels of a BGR image."""

    def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None:
        self.hgain = hgain
        self.sgain = sgain
        self.vgain = vgain

    def __call__(self, labels):
        """Jitter ``labels["img"]`` in place and return ``labels``.

        Nothing is done when all three gains are falsy.
        """
        img = labels["img"]
        if not (self.hgain or self.sgain or self.vgain):
            return labels

        # One multiplicative gain per channel, each in [1 - gain, 1 + gain].
        gains = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1
        hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
        out_dtype = img.dtype

        # 256-entry lookup tables: hue wraps at 180, saturation/value are clipped.
        ramp = np.arange(0, 256, dtype=gains.dtype)
        lut_hue = ((ramp * gains[0]) % 180).astype(out_dtype)
        lut_sat = np.clip(ramp * gains[1], 0, 255).astype(out_dtype)
        lut_val = np.clip(ramp * gains[2], 0, 255).astype(out_dtype)

        channels = (cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))
        im_hsv = cv2.merge(channels)
        # The result is written into ``img`` via ``dst``, so no reassignment is needed.
        cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img)
        return labels
class RandomFlip:
    """Flip the image and its instances along one axis with probability ``p``."""

    def __init__(self, p=0.5, direction="horizontal") -> None:
        assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
        assert 0 <= p <= 1.0

        self.p = p
        self.direction = direction

    def __call__(self, labels):
        """Possibly flip ``labels["img"]`` and ``labels["instances"]``; return ``labels``."""
        img = labels["img"]
        instances = labels.pop("instances")
        instances.convert_bbox(format="xywh")

        # Normalised instances are flipped about 1; pixel instances about the image extent.
        height, width = img.shape[:2]
        if instances.normalized:
            height = 1
            width = 1

        if self.direction == "vertical":
            if random.random() < self.p:  # flip up-down
                img = np.flipud(img)
                instances.flipud(height)
        elif self.direction == "horizontal":
            if random.random() < self.p:  # flip left-right
                img = np.fliplr(img)
                instances.fliplr(width)

        # np.flipud/np.fliplr return views; store a contiguous array.
        labels["img"] = np.ascontiguousarray(img)
        labels["instances"] = instances
        return labels
class LetterBox:
    """Resize image and padding for detection, instance segmentation, pose"""

    def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
        self.new_shape = new_shape
        self.auto = auto
        self.scaleFill = scaleFill
        self.scaleup = scaleup
        self.stride = stride

    def __call__(self, labels=None, image=None):
        """Resize and pad an image with value 114 to the target shape.

        The image is ``image`` when given, otherwise ``labels["img"]``. A
        ``rect_shape`` entry in ``labels`` (removed here) overrides
        ``self.new_shape``. Returns the updated ``labels`` dict when it is
        non-empty, otherwise just the image.
        """
        if labels is None:
            labels = {}
        img = labels.get("img") if image is None else image
        src_h, src_w = img.shape[:2]  # current height, width

        target = labels.pop("rect_shape", self.new_shape)
        if isinstance(target, int):
            target = (target, target)

        # Scale ratio (new / old)
        gain = min(target[0] / src_h, target[1] / src_w)
        if not self.scaleup:  # only scale down, do not scale up (for better val mAP)
            gain = min(gain, 1.0)

        # Compute padding
        ratio = gain, gain  # width, height ratios
        unpadded = int(round(src_w * gain)), int(round(src_h * gain))  # (w, h)
        pad_w = target[1] - unpadded[0]
        pad_h = target[0] - unpadded[1]
        if self.auto:  # minimum rectangle
            pad_w, pad_h = np.mod(pad_w, self.stride), np.mod(pad_h, self.stride)
        elif self.scaleFill:  # stretch
            pad_w, pad_h = 0.0, 0.0
            unpadded = (target[1], target[0])
            ratio = target[1] / src_w, target[0] / src_h  # width, height ratios

        # divide padding into 2 sides
        pad_w /= 2
        pad_h /= 2

        if (src_w, src_h) != unpadded:  # resize
            img = cv2.resize(img, unpadded, interpolation=cv2.INTER_LINEAR)

        # The +/-0.1 nudges put an odd leftover pixel on the bottom/right side.
        top = int(round(pad_h - 0.1))
        bottom = int(round(pad_h + 0.1))
        left = int(round(pad_w - 0.1))
        right = int(round(pad_w + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
                                 value=(114, 114, 114))  # add border

        if not len(labels):
            return img

        labels = self._update_labels(labels, ratio, pad_w, pad_h)
        labels["img"] = img
        labels["resized_shape"] = target
        return labels

    def _update_labels(self, labels, ratio, padw, padh):
        """Convert instances to pixel xyxy, then scale by ``ratio`` and shift by the padding."""
        instances = labels["instances"]
        instances.convert_bbox(format="xyxy")
        # labels["img"] still holds the pre-resize image at this point.
        instances.denormalize(*labels["img"].shape[:2][::-1])
        instances.scale(*ratio)
        instances.add_padding(padw, padh)
        return labels
class CopyPaste:
    """Copy-Paste augmentation (https://arxiv.org/abs/2012.07177) using left-right flipped instances.

    Args:
        p (float): Fraction of the eligible instances that get pasted; a falsy
            value disables the augmentation.
    """

    def __init__(self, p=0.5) -> None:
        self.p = p

    def __call__(self, labels):
        """Paste flipped copies of some instances into the image and append their labels.

        Args:
            labels (dict): Must hold ``img``, ``cls`` and ``instances``. The
                image array is modified in place.

        Returns:
            dict: ``labels`` with ``img``, ``cls`` and ``instances`` updated.
            Nothing is pasted when ``p`` is falsy or there are no segments.
        """
        im = labels["img"]
        cls = labels["cls"]
        instances = labels.pop("instances")
        instances.convert_bbox(format="xyxy")
        if self.p and len(instances.segments):
            _, w, _ = im.shape  # height, width, channels
            im_new = np.zeros(im.shape, np.uint8)  # mask of the pixels to paste

            # calculate ioa first then select indexes randomly
            ins_flip = deepcopy(instances)
            ins_flip.fliplr(w)

            # Keep only flipped instances that overlap every existing box by less than 30%.
            ioa = bbox_ioa(ins_flip.bboxes, instances.bboxes)  # intersection over area, (N, M)
            indexes = np.nonzero((ioa < 0.30).all(1))[0]  # (N, )
            n = len(indexes)
            for j in random.sample(list(indexes), k=round(self.p * n)):
                cls = np.concatenate((cls, cls[[j]]), axis=0)
                instances = Instances.concatenate((instances, ins_flip[[j]]), axis=0)
                # Draw the ORIGINAL segment; the whole mask is flipped once below.
                cv2.drawContours(im_new, instances.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED)

            result = cv2.flip(im, 1)  # augment segments (flip left-right)
            i = cv2.flip(im_new, 1).astype(bool)
            im[i] = result[i]  # cv2.imwrite('debug.jpg', im)  # debug

        labels["img"] = im
        labels["cls"] = cls
        labels["instances"] = instances
        return labels
class Albumentations:
    """Optional Albumentations pipeline; only active if the package is installed.

    Args:
        p (float): Probability of running the pipeline on a sample.

    ``self.transform`` stays ``None`` when the package is missing or setup fails.
    """

    def __init__(self, p=1.0):
        self.p = p
        self.transform = None
        prefix = colorstr("albumentations: ")
        try:
            import albumentations as A

            check_version(A.__version__, "1.0.3", hard=True)  # version requirement

            # Named ``transforms`` so it does not shadow the module-level alias ``T``.
            transforms = [
                A.Blur(p=0.01),
                A.MedianBlur(p=0.01),
                A.ToGray(p=0.01),
                A.CLAHE(p=0.01),
                A.RandomBrightnessContrast(p=0.0),
                A.RandomGamma(p=0.0),
                A.ImageCompression(quality_lower=75, p=0.0),]  # transforms
            self.transform = A.Compose(transforms,
                                       bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))

            LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in transforms if x.p))
        except ImportError:  # package not installed, skip
            pass
        except Exception as e:
            LOGGER.info(f"{prefix}{e}")

    def __call__(self, labels):
        """Run the pipeline on a sample that has at least one class label.

        Instances are converted to normalized xywh first. When the pipeline
        runs and returns at least one label, ``img``, ``cls`` and the boxes are
        all taken from its output so they stay aligned; otherwise the sample's
        image and classes are left as they were.

        Args:
            labels (dict): Must hold ``img``, ``cls`` and ``instances``.

        Returns:
            dict: The same ``labels`` dict.
        """
        im = labels["img"]
        cls = labels["cls"]
        if len(cls):
            labels["instances"].convert_bbox("xywh")
            labels["instances"].normalize(*im.shape[:2][::-1])
            bboxes = labels["instances"].bboxes
            # TODO: add supports of segments and keypoints
            if self.transform and random.random() < self.p:
                new = self.transform(image=im, bboxes=bboxes, class_labels=cls)  # transformed
                if len(new["class_labels"]) > 0:  # skip the update if every box was dropped
                    labels["img"] = new["image"]
                    labels["cls"] = np.array(new["class_labels"])
                    # Use the returned boxes: the pipeline may drop boxes, and
                    # the classes above come from the same output.
                    bboxes = np.array(new["bboxes"], dtype=bboxes.dtype)
            labels["instances"].update(bboxes=bboxes)
        return labels
# TODO: technically this is not an augmentation; maybe we should move it to another file
class Format:
def __init__(self,
bbox_format="xywh",
normalize=True,
return_mask=False,
return_keypoint=False,
mask_ratio=4,
mask_overlap=True,
batch_idx=True):
self.bbox_format = bbox_format
self.normalize = normalize
self.return_mask = return_mask # set False when training detection only
self.return_keypoint = return_keypoint
self.mask_ratio = mask_ratio
self.mask_overlap = mask_overlap
self.batch_idx = batch_idx # keep the batch indexes
def __call__(self, labels):
img = labels["img"]
h, w = img.shape[:2]
cls = labels.pop("cls")
instances = labels.pop("instances")
instances.convert_bbox(format=self.bbox_format)
instances.denormalize(w, h)
nl = len(instances)
if self.return_mask:
if nl:
masks, instances, cls = self._format_segments(instances, cls, w, h)
masks = torch.from_numpy(masks)
else:
masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
img.shape[1] // self.mask_ratio)
labels["masks"] = masks
if self.normalize:
instances.normalize(w, h)
labels["img"] = self._format_img(img)
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
if self.return_keypoint:
labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
# then we can use collate_fn
if self.batch_idx:
labels["batch_idx"] = torch.zeros(nl)
return labels
def _format_img(self, img):
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
img = torch.from_numpy(img)
return img
def _format_segments(self, instances, cls, w, h):
"""convert polygon points to bitmap"""
segments = instances.segments
if self.mask_overlap:
masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio)
masks = masks[None] # (640, 640) -> (1, 640, 640)
instances = instances[sorted_idx]
cls = cls[sorted_idx]
else:
masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio)
return masks, instances, cls
def mosaic_transforms(dataset, imgsz, hyp):
    """Build the training pipeline that starts with mosaic augmentation.

    Mosaic, CopyPaste and RandomPerspective form a pre-transform that is used
    both as the first stage and as the ``pre_transform`` of MixUp. It is
    followed by Albumentations, HSV jitter and vertical/horizontal flips, all
    configured from ``hyp``.
    """
    offset = -imgsz // 2  # negative border on each side of the 2x mosaic canvas

    mosaic = Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, border=[offset, offset])
    copy_paste = CopyPaste(p=hyp.copy_paste)
    perspective = RandomPerspective(
        degrees=hyp.degrees,
        translate=hyp.translate,
        scale=hyp.scale,
        shear=hyp.shear,
        perspective=hyp.perspective,
        border=[offset, offset],
    )
    pre_transform = Compose([mosaic, copy_paste, perspective])

    stages = [
        pre_transform,
        MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
        Albumentations(p=1.0),
        RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
        RandomFlip(direction="vertical", p=hyp.flipud),
        RandomFlip(direction="horizontal", p=hyp.fliplr),
    ]
    return Compose(stages)  # transforms
def affine_transforms(imgsz, hyp):
    """Build the training pipeline used without mosaic.

    The sample is letterboxed to ``(imgsz, imgsz)``, then passed through
    RandomPerspective (no border), Albumentations, HSV jitter and
    vertical/horizontal flips, all configured from ``hyp``.
    """
    letterbox = LetterBox(new_shape=(imgsz, imgsz))
    perspective = RandomPerspective(
        degrees=hyp.degrees,
        translate=hyp.translate,
        scale=hyp.scale,
        shear=hyp.shear,
        perspective=hyp.perspective,
        border=[0, 0],
    )
    stages = [
        letterbox,
        perspective,
        Albumentations(p=1.0),
        RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
        RandomFlip(direction="vertical", p=hyp.flipud),
        RandomFlip(direction="horizontal", p=hyp.fliplr),
    ]
    return Compose(stages)  # transforms
# Classification augmentations -----------------------------------------------------------------------------------------
def classify_transforms(size=224):
    """Return the classification transforms to apply if albumentations is not installed.

    The pipeline is CenterCrop(size) -> ToTensor() -> Normalize with the
    ImageNet mean/std. ``size`` must be an int.
    """
    assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"
    steps = [
        CenterCrop(size),
        ToTensor(),
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return T.Compose(steps)
def classify_albumentations(
        augment=True,
        size=224,
        scale=(0.08, 1.0),
        hflip=0.5,
        vflip=0.0,
        jitter=0.4,
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD,
        auto_aug=False,
):
    """Build an Albumentations classification pipeline (only if the package is installed).

    With ``augment`` the pipeline is a random resized crop plus optional flips
    and colour jitter; otherwise it is a fixed resize and centre crop. Both
    end with normalisation and tensor conversion.

    Returns:
        The composed pipeline, or ``None`` when albumentations cannot be
        imported or building the pipeline raises (the error is logged).
    """
    prefix = colorstr("albumentations: ")
    try:
        import albumentations as A
        from albumentations.pytorch import ToTensorV2

        check_version(A.__version__, "1.0.3", hard=True)  # version requirement
        if not augment:
            # Use fixed crop for eval set (reproducibility)
            pipeline = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
        else:
            # Resize and crop
            pipeline = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
            if auto_aug:
                # TODO: implement AugMix, AutoAug & RandAug in albumentation
                LOGGER.info(f"{prefix}auto augmentations are currently not supported")
            else:
                if hflip > 0:
                    pipeline.append(A.HorizontalFlip(p=hflip))
                if vflip > 0:
                    pipeline.append(A.VerticalFlip(p=vflip))
                if jitter > 0:
                    color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, saturation, 0 hue
                    pipeline.append(A.ColorJitter(*color_jitter, 0))

        # Normalize and convert to Tensor
        pipeline.extend([A.Normalize(mean=mean, std=std), ToTensorV2()])
        LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in pipeline if x.p))
        return A.Compose(pipeline)

    except ImportError:  # package not installed, skip
        pass
    except Exception as e:
        LOGGER.info(f"{prefix}{e}")
    return None
class ClassifyLetterBox:
    """LetterBox for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()]).

    Args:
        size (int | tuple): Target ``(h, w)``; an int is used for both sides.
        auto (bool): If True, the output canvas is the resized image rounded up
            to a multiple of ``stride`` instead of the full ``(h, w)``.
        stride (int): Stride used with ``auto``.
    """

    def __init__(self, size=(640, 640), auto=False, stride=32):
        super().__init__()
        self.h, self.w = (size, size) if isinstance(size, int) else size
        self.auto = auto  # pass max size integer, automatically solve for short side using stride
        self.stride = stride  # used with auto

    def _canvas_size(self, h, w):
        """Return the output canvas ``(hs, ws)`` for a resized image of size ``(h, w)``."""
        if self.auto:
            return tuple(math.ceil(x / self.stride) * self.stride for x in (h, w))
        return self.h, self.w

    def __call__(self, im):  # im = np.array HWC
        """Resize ``im`` keeping its aspect ratio and centre it on a 114-filled canvas."""
        imh, imw = im.shape[:2]
        r = min(self.h / imh, self.w / imw)  # ratio of new/old
        h, w = round(imh * r), round(imw * r)  # resized image
        # Without explicit grouping, ``a if c else b, d`` parses as ``(a if c else b), d``,
        # which bound hs to a generator when auto=True; the helper returns a real pair.
        hs, ws = self._canvas_size(h, w)
        top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
        # The canvas follows (hs, ws) so that ``auto`` actually shrinks the output.
        im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)
        im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
        return im_out
class CenterCrop:
    """CenterCrop for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])."""

    def __init__(self, size=640):
        super().__init__()
        if isinstance(size, int):
            self.h, self.w = size, size
        else:
            self.h, self.w = size

    def __call__(self, im):  # im = np.array HWC
        """Crop the largest centred square from ``im`` and resize it to ``(self.w, self.h)``."""
        src_h, src_w = im.shape[:2]
        side = min(src_h, src_w)  # min dimension
        y0 = (src_h - side) // 2
        x0 = (src_w - side) // 2
        square = im[y0:y0 + side, x0:x0 + side]
        return cv2.resize(square, (self.w, self.h), interpolation=cv2.INTER_LINEAR)
class ToTensor:
    """ToTensor for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])."""

    def __init__(self, half=False):
        super().__init__()
        self.half = half

    def __call__(self, im):  # im = np.array HWC in BGR order
        """Convert an HWC array to a CHW tensor with reversed channel order, scaled by 1/255."""
        # HWC to CHW, then reverse the channel axis (BGR to RGB).
        chw = im.transpose((2, 0, 1))[::-1]
        tensor = torch.from_numpy(np.ascontiguousarray(chw))
        # uint8 to fp16/32
        if self.half:
            tensor = tensor.half()
        else:
            tensor = tensor.float()
        tensor /= 255.0  # 0-255 to 0.0-1.0
        return tensor