Simplify augmentations (#93)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2022-12-25 21:21:26 +01:00
committed by GitHub
parent 249dfbdc05
commit ae05d44877
5 changed files with 36 additions and 68 deletions

View File

@ -1,4 +1,3 @@
import collections
import math
import random
from copy import deepcopy
@ -65,7 +64,8 @@ class Compose:
class BaseMixTransform:
"""This implementation is from mmyolo"""
def __init__(self, pre_transform=None, p=0.0) -> None:
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
self.dataset = dataset
self.pre_transform = pre_transform
self.p = p
@ -73,41 +73,28 @@ class BaseMixTransform:
if random.uniform(0, 1) > self.p:
return labels
assert "dataset" in labels
dataset = labels.pop("dataset")
# get index of one or three other images
indexes = self.get_indexes(dataset)
if not isinstance(indexes, collections.abc.Sequence):
indexes = self.get_indexes()
if isinstance(indexes, int):
indexes = [indexes]
# get images information will be used for Mosaic or MixUp
mix_labels = [dataset.get_label_info(index) for index in indexes]
mix_labels = [self.dataset.get_label_info(i) for i in indexes]
if self.pre_transform is not None:
for i, data in enumerate(mix_labels):
# pre_transform may also require dataset
data.update({"dataset": dataset})
# before Mosaic or MixUp need to go through
# the necessary pre_transform
_labels = self.pre_transform(data)
_labels.pop("dataset")
mix_labels[i] = _labels
mix_labels[i] = self.pre_transform(data)
labels["mix_labels"] = mix_labels
# Mosaic or MixUp
labels = self._mix_transform(labels)
if "mix_labels" in labels:
labels.pop("mix_labels")
labels["dataset"] = dataset
labels.pop("mix_labels", None)
return labels
def _mix_transform(self, labels):
raise NotImplementedError
def get_indexes(self, dataset):
def get_indexes(self):
raise NotImplementedError
@ -119,14 +106,15 @@ class Mosaic(BaseMixTransform):
Default to (640, 640).
"""
def __init__(self, imgsz=640, p=1.0, border=(0, 0)):
def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
assert 0 <= p <= 1.0, "The probability should be in range [0, 1]. " f"got {p}."
super().__init__(pre_transform=None, p=p)
super().__init__(dataset=dataset, p=p)
self.dataset = dataset
self.imgsz = imgsz
self.border = border
def get_indexes(self, dataset):
return [random.randint(0, len(dataset) - 1) for _ in range(3)]
def get_indexes(self):
return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]
def _mix_transform(self, labels):
mosaic_labels = []
@ -193,25 +181,19 @@ class Mosaic(BaseMixTransform):
class MixUp(BaseMixTransform):
def __init__(self, pre_transform=None, p=0.0) -> None:
super().__init__(pre_transform=pre_transform, p=p)
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
def get_indexes(self, dataset):
return random.randint(0, len(dataset) - 1)
def get_indexes(self):
return random.randint(0, len(self.dataset) - 1)
def _mix_transform(self, labels):
im = labels["img"]
labels2 = labels["mix_labels"][0]
im2 = labels2["img"]
cls2 = labels2["cls"]
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
im = (im * r + im2 * (1 - r)).astype(np.uint8)
cat_instances = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
cls = labels["cls"]
labels["img"] = im
labels["instances"] = cat_instances
labels["cls"] = np.concatenate([cls, cls2], 0)
labels2 = labels["mix_labels"][0]
labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
return labels
@ -412,7 +394,6 @@ class RandomHSV:
im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
labels["img"] = img
return labels
@ -606,7 +587,6 @@ class Format:
self.batch_idx = batch_idx # keep the batch indexes
def __call__(self, labels):
labels.pop("dataset", None)
img = labels["img"]
h, w = img.shape[:2]
cls = labels.pop("cls")
@ -656,9 +636,9 @@ class Format:
return masks, instances, cls
def mosaic_transforms(imgsz, hyp):
def mosaic_transforms(dataset, imgsz, hyp):
pre_transform = Compose([
Mosaic(imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
CopyPaste(p=hyp.copy_paste),
RandomPerspective(
degrees=hyp.degrees,
@ -670,7 +650,7 @@ def mosaic_transforms(imgsz, hyp):
),])
return Compose([
pre_transform,
MixUp(pre_transform=pre_transform, p=hyp.mixup),
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
Albumentations(p=1.0),
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
RandomFlip(direction="vertical", p=hyp.flipud),

View File

@ -42,12 +42,13 @@ class BaseDataset(Dataset):
self.imgsz = imgsz
self.label_path = label_path
self.augment = augment
self.single_cls = single_cls
self.prefix = prefix
self.im_files = self.get_img_files(self.img_path)
self.labels = self.get_labels()
if single_cls:
self.update_labels(include_class=[], single_cls=single_cls)
if self.single_cls:
self.update_labels(include_class=[])
self.ni = len(self.im_files)
@ -173,10 +174,7 @@ class BaseDataset(Dataset):
self.batch = bi # batch index of image
def __getitem__(self, index):
label = self.get_label_info(index)
if self.augment:
label["dataset"] = self
return self.transforms(label)
return self.transforms(self.get_label_info(index))
def get_label_info(self, index):
label = self.labels[index].copy()

View File

@ -1,7 +1,6 @@
from itertools import repeat
from multiprocessing.pool import Pool
from pathlib import Path
from typing import OrderedDict
import torchvision
from tqdm import tqdm
@ -126,7 +125,7 @@ class YOLODataset(BaseDataset):
def build_transforms(self, hyp=None):
if self.augment:
mosaic = self.augment and not self.rect
transforms = mosaic_transforms(self.imgsz, hyp) if mosaic else affine_transforms(self.imgsz, hyp)
transforms = mosaic_transforms(self, self.imgsz, hyp) if mosaic else affine_transforms(self.imgsz, hyp)
else:
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz))])
transforms.append(