Simplify augmentations (#93)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Glenn Jocher (committed via GitHub)
parent 249dfbdc05
commit ae05d44877

@@ -1,4 +1,3 @@
-import collections
 import math
 import random
 from copy import deepcopy
@@ -65,7 +64,8 @@ class Compose:

 class BaseMixTransform:
     """This implementation is from mmyolo"""

-    def __init__(self, pre_transform=None, p=0.0) -> None:
+    def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
+        self.dataset = dataset
         self.pre_transform = pre_transform
         self.p = p
@@ -73,41 +73,28 @@ class BaseMixTransform:
         if random.uniform(0, 1) > self.p:
             return labels

-        assert "dataset" in labels
-        dataset = labels.pop("dataset")
-
         # get index of one or three other images
-        indexes = self.get_indexes(dataset)
-        if not isinstance(indexes, collections.abc.Sequence):
+        indexes = self.get_indexes()
+        if isinstance(indexes, int):
             indexes = [indexes]

         # get images information will be used for Mosaic or MixUp
-        mix_labels = [dataset.get_label_info(index) for index in indexes]
+        mix_labels = [self.dataset.get_label_info(i) for i in indexes]

         if self.pre_transform is not None:
             for i, data in enumerate(mix_labels):
-                # pre_transform may also require dataset
-                data.update({"dataset": dataset})
-                # before Mosaic or MixUp need to go through
-                # the necessary pre_transform
-                _labels = self.pre_transform(data)
-                _labels.pop("dataset")
-                mix_labels[i] = _labels
+                mix_labels[i] = self.pre_transform(data)
         labels["mix_labels"] = mix_labels

         # Mosaic or MixUp
         labels = self._mix_transform(labels)
-
-        if "mix_labels" in labels:
-            labels.pop("mix_labels")
-        labels["dataset"] = dataset
+        labels.pop("mix_labels", None)
         return labels

     def _mix_transform(self, labels):
         raise NotImplementedError

-    def get_indexes(self, dataset):
+    def get_indexes(self):
         raise NotImplementedError
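The refactor replaces the `labels["dataset"]` hand-off with a `dataset` attribute set at construction time. A minimal sketch of the new control flow, using a hypothetical `ToyDataset` in place of the real `YOLODataset` (names here are illustrative, not from the repo):

```python
import random

class ToyDataset:
    """Hypothetical stand-in for YOLODataset; only len() and get_label_info() matter here."""

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def get_label_info(self, i):
        return dict(self.items[i])  # return a copy, as the real dataset does

class ToyMix:
    """Follows the simplified BaseMixTransform contract: dataset held as an attribute."""

    def __init__(self, dataset, p=1.0):
        self.dataset = dataset
        self.p = p

    def get_indexes(self):
        return random.randint(0, len(self.dataset) - 1)  # a bare int gets wrapped below

    def __call__(self, labels):
        if random.uniform(0, 1) > self.p:
            return labels
        indexes = self.get_indexes()
        if isinstance(indexes, int):
            indexes = [indexes]
        # no labels.pop("dataset") needed anymore; self.dataset is always available
        labels["mix_labels"] = [self.dataset.get_label_info(i) for i in indexes]
        labels["img"] = labels["img"] + labels["mix_labels"][0]["img"]  # toy "blend"
        labels.pop("mix_labels", None)
        return labels

labels = ToyMix(ToyDataset([{"img": 1}, {"img": 2}]))({"img": 10})
print(labels["img"])  # 11 or 12
```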
@@ -119,14 +106,15 @@ class Mosaic(BaseMixTransform):
         Default to (640, 640).
     """

-    def __init__(self, imgsz=640, p=1.0, border=(0, 0)):
+    def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
         assert 0 <= p <= 1.0, "The probability should be in range [0, 1]. " f"got {p}."
-        super().__init__(pre_transform=None, p=p)
+        super().__init__(dataset=dataset, p=p)
+        self.dataset = dataset
         self.imgsz = imgsz
         self.border = border

-    def get_indexes(self, dataset):
-        return [random.randint(0, len(dataset) - 1) for _ in range(3)]
+    def get_indexes(self):
+        return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]

     def _mix_transform(self, labels):
         mosaic_labels = []
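For context, `get_indexes` draws three additional indices because a 2x2 mosaic needs the current image plus three sampled ones. A toy check (the dataset length is made up):

```python
import random

dataset_len = 128  # made-up length; any non-empty dataset works
extra = [random.randint(0, dataset_len - 1) for _ in range(3)]
assert len(extra) == 3  # current image + 3 sampled tiles = 4 mosaic quadrants
```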
@@ -193,25 +181,19 @@ class Mosaic(BaseMixTransform):

 class MixUp(BaseMixTransform):

-    def __init__(self, pre_transform=None, p=0.0) -> None:
-        super().__init__(pre_transform=pre_transform, p=p)
+    def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
+        super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)

-    def get_indexes(self, dataset):
-        return random.randint(0, len(dataset) - 1)
+    def get_indexes(self):
+        return random.randint(0, len(self.dataset) - 1)

     def _mix_transform(self, labels):
-        im = labels["img"]
-        labels2 = labels["mix_labels"][0]
-        im2 = labels2["img"]
-        cls2 = labels2["cls"]
         # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
         r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
-        im = (im * r + im2 * (1 - r)).astype(np.uint8)
-        cat_instances = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
-        cls = labels["cls"]
-        labels["img"] = im
-        labels["instances"] = cat_instances
-        labels["cls"] = np.concatenate([cls, cls2], 0)
+        labels2 = labels["mix_labels"][0]
+        labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
+        labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
+        labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
         return labels
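The `Beta(32, 32)` draw concentrates tightly around 0.5, so the two images contribute almost equally. A self-contained numeric check, with dummy arrays standing in for real images:

```python
import numpy as np

r = np.random.beta(32.0, 32.0)            # mixup ratio, clustered near 0.5
im1 = np.full((4, 4, 3), 200, np.uint8)   # dummy "image" 1
im2 = np.full((4, 4, 3), 50, np.uint8)    # dummy "image" 2
mixed = (im1 * r + im2 * (1 - r)).astype(np.uint8)
print(round(float(r), 3), mixed[0, 0])    # r ~ 0.5, pixels ~ 125
```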
@@ -412,7 +394,6 @@ class RandomHSV:
         im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
         cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed
-        labels["img"] = img
         return labels
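The removed line was redundant: `cv2.cvtColor(..., dst=img)` writes into the buffer that `labels["img"]` already references. A minimal demonstration of the in-place behavior:

```python
import cv2
import numpy as np

img = np.zeros((2, 2, 3), np.uint8)
labels = {"img": img}                             # dict holds a reference, not a copy
im_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img)  # mutates img in place
assert labels["img"] is img                       # same object; reassignment was redundant
```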
@@ -606,7 +587,6 @@ class Format:
         self.batch_idx = batch_idx  # keep the batch indexes

     def __call__(self, labels):
-        labels.pop("dataset", None)
         img = labels["img"]
         h, w = img.shape[:2]
         cls = labels.pop("cls")
@@ -656,9 +636,9 @@ class Format:
         return masks, instances, cls


-def mosaic_transforms(imgsz, hyp):
+def mosaic_transforms(dataset, imgsz, hyp):
     pre_transform = Compose([
-        Mosaic(imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
+        Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
         CopyPaste(p=hyp.copy_paste),
         RandomPerspective(
             degrees=hyp.degrees,
@@ -670,7 +650,7 @@ def mosaic_transforms(imgsz, hyp):
         ),])
     return Compose([
         pre_transform,
-        MixUp(pre_transform=pre_transform, p=hyp.mixup),
+        MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
         Albumentations(p=1.0),
         RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
         RandomFlip(direction="vertical", p=hyp.flipud),
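With the dataset threaded in at construction, building the training pipeline becomes a single call. A hedged usage sketch: the `hyp` fields are guessed from the attribute accesses above, and `dataset` is assumed to be a `YOLODataset`, so the actual calls are left commented:

```python
from types import SimpleNamespace

# Guessed hyperparameter fields, matching the attributes read in the hunks above.
hyp = SimpleNamespace(mosaic=1.0, copy_paste=0.0, degrees=0.0, translate=0.1,
                      scale=0.5, shear=0.0, perspective=0.0, mixup=0.0,
                      hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, flipud=0.0, fliplr=0.5)
# transforms = mosaic_transforms(dataset, imgsz=640, hyp=hyp)
# sample = transforms(dataset.get_label_info(0))  # labels dict in, augmented dict out
```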

@@ -42,12 +42,13 @@ class BaseDataset(Dataset):
         self.imgsz = imgsz
         self.label_path = label_path
         self.augment = augment
+        self.single_cls = single_cls
         self.prefix = prefix

         self.im_files = self.get_img_files(self.img_path)
         self.labels = self.get_labels()
-        if single_cls:
-            self.update_labels(include_class=[], single_cls=single_cls)
+        if self.single_cls:
+            self.update_labels(include_class=[])

         self.ni = len(self.im_files)
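For reference, single-class mode collapses every class id to 0. The real work happens inside `update_labels`, but the effect is equivalent to this small sketch (values are made up):

```python
import numpy as np

cls = np.array([[3], [7], [0]])   # made-up per-box class ids
single_cls = True
if single_cls:
    cls[:, 0] = 0                 # all boxes become class 0
print(cls.ravel())                # [0 0 0]
```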
@@ -173,10 +174,7 @@ class BaseDataset(Dataset):
         self.batch = bi  # batch index of image

     def __getitem__(self, index):
-        label = self.get_label_info(index)
-        if self.augment:
-            label["dataset"] = self
-        return self.transforms(label)
+        return self.transforms(self.get_label_info(index))

     def get_label_info(self, index):
         label = self.labels[index].copy()

@@ -1,7 +1,6 @@
 from itertools import repeat
 from multiprocessing.pool import Pool
 from pathlib import Path
-from typing import OrderedDict

 import torchvision
 from tqdm import tqdm
@@ -126,7 +125,7 @@ class YOLODataset(BaseDataset):
     def build_transforms(self, hyp=None):
         if self.augment:
             mosaic = self.augment and not self.rect
-            transforms = mosaic_transforms(self.imgsz, hyp) if mosaic else affine_transforms(self.imgsz, hyp)
+            transforms = mosaic_transforms(self, self.imgsz, hyp) if mosaic else affine_transforms(self.imgsz, hyp)
         else:
             transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz))])
         transforms.append(

@@ -72,18 +72,12 @@ weight_decay: 0.0005  # optimizer weight decay 5e-4
 warmup_epochs: 3.0  # warmup epochs (fractions ok)
 warmup_momentum: 0.8  # warmup initial momentum
 warmup_bias_lr: 0.1  # warmup initial bias lr
-box: 0.05  # box loss gain
-cls: 0.5  # cls loss gain
-cls_pw: 1.0  # cls BCELoss positive_weight
-obj: 1.0  # obj loss gain (scale with pixels)
-obj_pw: 1.0  # obj BCELoss positive_weight
-iou_t: 0.20  # IoU training threshold
-anchor_t: 4.0  # anchor-multiple threshold
-# anchors: 3  # anchors per output layer (0 to ignore)
+box: 7.5  # box loss gain
+cls: 0.5  # cls loss gain (scale with pixels)
+dfl: 1.5  # dfl loss gain
 fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5)
 label_smoothing: 0.0
 nbs: 64  # nominal batch size
-# anchors: 3
 hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
 hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
 hsv_v: 0.4  # image HSV-Value augmentation (fraction)
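The hyperparameter swap mirrors the move to an anchor-free, DFL-based head: objectness and anchor settings disappear, and `box`, `cls`, `dfl` weight the three remaining loss terms. A hedged sketch of how the gains combine, with made-up raw term values:

```python
box, cls, dfl = 7.5, 0.5, 1.5                   # gains from this file
loss_box, loss_cls, loss_dfl = 0.04, 0.9, 0.6   # made-up per-term values
total = box * loss_box + cls * loss_cls + dfl * loss_dfl
print(total)                                    # 0.3 + 0.45 + 0.9 = 1.65
```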

@@ -30,8 +30,8 @@ class DetectionTrainer(BaseTrainer):
     def set_model_attributes(self):
         nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)
         self.args.box *= 3 / nl  # scale to layers
-        self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
-        self.args.obj *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
+        # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
+        self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
         self.model.nc = self.data["nc"]  # attach number of classes to model
         self.model.args = self.args  # attach hyperparameters to model
         # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
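A quick check of the new scaling: at the defaults `nl=3` and `imgsz=640`, both multipliers are exactly 1, so the YAML gains pass through unchanged:

```python
nl, imgsz = 3, 640
box_scale = 3 / nl                        # 1.0
cls_scale = (imgsz / 640) ** 2 * 3 / nl   # 1.0
print(box_scale, cls_scale)
```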
@@ -85,14 +85,11 @@ class Loss:
         device = next(model.parameters()).device  # get model device
         h = model.args  # hyperparameters

-        # Define criteria
-        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h["cls_pw"]], device=device), reduction='none')
-
         # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
         self.cp, self.cn = smooth_BCE(eps=h.get("label_smoothing", 0.0))  # positive, negative BCE targets

         m = model.model[-1]  # Detect() module
-        self.BCEcls = BCEcls
+        self.bce = nn.BCEWithLogitsLoss(reduction='none')
         self.hyp = h
         self.stride = m.stride  # model strides
         self.nc = m.nc  # number of classes
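The new criterion is a plain `BCEWithLogitsLoss` without `pos_weight` (the `cls_pw` hyperparameter was removed above), kept at `reduction='none'` so the caller can sum and normalize, as the cls-loss line in the next hunk does. A minimal, self-contained sketch with dummy tensors:

```python
import torch
from torch import nn

bce = nn.BCEWithLogitsLoss(reduction='none')  # per-element losses, no pos_weight
pred_scores = torch.randn(2, 3)               # dummy logits
target_scores = torch.rand(2, 3)              # dummy soft targets in [0, 1]
target_scores_sum = target_scores.sum().clamp(min=1)
loss_cls = bce(pred_scores, target_scores).sum() / target_scores_sum
print(loss_cls.item())
```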
@@ -156,7 +153,7 @@ class Loss:

         # cls loss
         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-        loss[1] = self.BCEcls(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

         # bbox loss
         if fg_mask.sum():
