ultralytics 8.0.24 mosaic, DDP, download fixes (#703)

Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-29 18:10:25 +01:00
committed by GitHub
parent 899abe9f82
commit aecd17d455
15 changed files with 120 additions and 98 deletions

View File

@ -44,20 +44,8 @@ class Compose:
self.transforms = transforms
def __call__(self, data):
mosaic_p = None
mosaic_imgsz = None
for t in self.transforms:
if isinstance(t, Mosaic):
temp = t(data)
mosaic_p = False if temp == data else True
mosaic_imgsz = t.imgsz
data = temp
else:
if isinstance(t, RandomPerspective):
t.border = [-mosaic_imgsz // 2, -mosaic_imgsz // 2] if mosaic_p else [0, 0]
data = t(data)
data = t(data)
return data
def append(self, transform):
@ -140,7 +128,7 @@ class Mosaic(BaseMixTransform):
labels_patch = (labels if i == 0 else labels["mix_labels"][i - 1]).copy()
# Load image
img = labels_patch["img"]
h, w = labels_patch["resized_shape"]
h, w = labels_patch.pop("resized_shape")
# place img in img4
if i == 0: # top left
@ -184,11 +172,12 @@ class Mosaic(BaseMixTransform):
cls.append(labels["cls"])
instances.append(labels["instances"])
final_labels = {
"im_file": mosaic_labels[0]["im_file"],
"ori_shape": mosaic_labels[0]["ori_shape"],
"resized_shape": (self.imgsz * 2, self.imgsz * 2),
"im_file": mosaic_labels[0]["im_file"],
"cls": np.concatenate(cls, 0),
"instances": Instances.concatenate(instances, axis=0)}
"instances": Instances.concatenate(instances, axis=0),
"mosaic_border": self.border}
final_labels["instances"].clip(self.imgsz * 2, self.imgsz * 2)
return final_labels
@ -213,7 +202,14 @@ class MixUp(BaseMixTransform):
class RandomPerspective:
def __init__(self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0)):
def __init__(self,
degrees=0.0,
translate=0.1,
scale=0.5,
shear=0.0,
perspective=0.0,
border=(0, 0),
pre_transform=None):
self.degrees = degrees
self.translate = translate
self.scale = scale
@ -221,8 +217,9 @@ class RandomPerspective:
self.perspective = perspective
# mosaic border
self.border = border
self.pre_transform = pre_transform
def affine_transform(self, img):
def affine_transform(self, img, border):
# Center
C = np.eye(3)
@ -255,7 +252,7 @@ class RandomPerspective:
# Combined rotation matrix
M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
# affine image
if (self.border[0] != 0) or (self.border[1] != 0) or (M != np.eye(3)).any(): # image changed
if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
if self.perspective:
img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))
else: # affine
@ -341,6 +338,10 @@ class RandomPerspective:
Args:
labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
"""
if self.pre_transform and "mosaic_border" not in labels:
labels = self.pre_transform(labels)
labels.pop("ratio_pad") # do not need ratio pad
img = labels["img"]
cls = labels["cls"]
instances = labels.pop("instances")
@ -348,10 +349,11 @@ class RandomPerspective:
instances.convert_bbox(format="xyxy")
instances.denormalize(*img.shape[:2][::-1])
self.size = img.shape[1] + self.border[1] * 2, img.shape[0] + self.border[0] * 2 # w, h
border = labels.pop("mosaic_border", self.border)
self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
# M is affine matrix
# scale for func:`box_candidates`
img, M, scale = self.affine_transform(img)
img, M, scale = self.affine_transform(img, border)
bboxes = self.apply_bboxes(instances.bboxes, M)
@ -513,8 +515,10 @@ class CopyPaste:
# Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
im = labels["img"]
cls = labels["cls"]
h, w = im.shape[:2]
instances = labels.pop("instances")
instances.convert_bbox(format="xyxy")
instances.denormalize(w, h)
if self.p and len(instances.segments):
n = len(instances)
_, w, _ = im.shape # height, width, channels
@ -605,7 +609,7 @@ class Format:
self.batch_idx = batch_idx # keep the batch indexes
def __call__(self, labels):
img = labels["img"]
img = labels.pop("img")
h, w = img.shape[:2]
cls = labels.pop("cls")
instances = labels.pop("instances")
@ -654,7 +658,7 @@ class Format:
return masks, instances, cls
def mosaic_transforms(dataset, imgsz, hyp):
def v8_transforms(dataset, imgsz, hyp):
pre_transform = Compose([
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
CopyPaste(p=hyp.copy_paste),
@ -664,7 +668,7 @@ def mosaic_transforms(dataset, imgsz, hyp):
scale=hyp.scale,
shear=hyp.shear,
perspective=hyp.perspective,
border=[-imgsz // 2, -imgsz // 2],
pre_transform=LetterBox(new_shape=(imgsz, imgsz)),
),])
return Compose([
pre_transform,
@ -675,23 +679,6 @@ def mosaic_transforms(dataset, imgsz, hyp):
RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
# Build the augmentation pipeline that does not include Mosaic: letterbox,
# random perspective/affine warp, Albumentations, HSV jitter, then random flips.
#   imgsz: target size; passed to LetterBox as new_shape=(imgsz, imgsz).
#   hyp:   hyperparameter object; this function reads degrees, translate, scale,
#          shear, perspective, hsv_h, hsv_s, hsv_v, flipud and fliplr from it.
# Returns a Compose wrapping the transforms in the order listed below.
def affine_transforms(imgsz, hyp):
return Compose([
LetterBox(new_shape=(imgsz, imgsz)),
RandomPerspective(
degrees=hyp.degrees,
translate=hyp.translate,
scale=hyp.scale,
shear=hyp.shear,
perspective=hyp.perspective,
# border is fixed at [0, 0] here; presumably because there is no mosaic canvas to crop (confirm)
border=[0, 0],
),
# p=1.0 is passed unconditionally; presumably "always apply" (confirm in Albumentations)
Albumentations(p=1.0),
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
# flip probabilities come from hyp.flipud (vertical) and hyp.fliplr (horizontal)
RandomFlip(direction="vertical", p=hyp.flipud),
RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
# Classification augmentations -----------------------------------------------------------------------------------------
def classify_transforms(size=224):
# Transforms to apply if albumentations not installed

View File

@ -182,6 +182,7 @@ class BaseDataset(Dataset):
def get_label_info(self, index):
label = self.labels[index].copy()
label.pop("shape", None) # shape is for rect, remove it
label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
label["ratio_pad"] = (
label["resized_shape"][0] / label["ori_shape"][0],

View File

@ -136,8 +136,9 @@ class YOLODataset(BaseDataset):
# TODO: use hyp config to set all these augmentations
def build_transforms(self, hyp=None):
if self.augment:
mosaic = self.augment and not self.rect
transforms = mosaic_transforms(self, self.imgsz, hyp) if mosaic else affine_transforms(self.imgsz, hyp)
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
transforms = v8_transforms(self, self.imgsz, hyp)
else:
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
transforms.append(
@ -151,15 +152,10 @@ class YOLODataset(BaseDataset):
return transforms
def close_mosaic(self, hyp):
self.transforms = affine_transforms(self.imgsz, hyp)
self.transforms.append(
Format(bbox_format="xywh",
normalize=True,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask))
hyp.mosaic = 0.0 # set mosaic ratio=0.0
hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
self.transforms = self.build_transforms(hyp)
def update_labels_info(self, label):
"""Customize your label format here."""
@ -175,8 +171,6 @@ class YOLODataset(BaseDataset):
@staticmethod
def collate_fn(batch):
# TODO: returning a dict could make things easier and cleaner when using the dataset in training,
# but it is unclear whether this would slow things down slightly.
new_batch = {}
keys = batch[0].keys()
values = list(zip(*[list(b.values()) for b in batch]))

View File

@ -246,7 +246,7 @@ def check_det_dataset(dataset, autodownload=True):
r = exec(s, {'yaml': data}) # return None
dt = f'({round(time.time() - t, 1)}s)'
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt}"
LOGGER.info(f"Dataset download {s}")
LOGGER.info(f"Dataset download {s}\n")
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
return data # dictionary