Fix dataloader (#32)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Author: Laughing
Date: 2022-11-04 03:00:40 -05:00
Committed by: GitHub
Parent: 92c60758dd
Commit: 2e9b18ce4e
8 changed files with 404 additions and 41 deletions


@@ -127,7 +127,7 @@ class Mosaic(BaseMixTransform):
         self.border = border
 
     def get_indexes(self, dataset):
-        return [random.randint(0, len(dataset)) for _ in range(3)]
+        return [random.randint(0, len(dataset) - 1) for _ in range(3)]
 
     def _mix_transform(self, labels):
         mosaic_labels = []
@@ -200,7 +200,7 @@ class MixUp(BaseMixTransform):
         super().__init__(pre_transform=pre_transform, p=p)
 
     def get_indexes(self, dataset):
-        return random.randint(0, len(dataset))
+        return random.randint(0, len(dataset) - 1)
 
     def _mix_transform(self, labels):
         im = labels["img"]
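Note: both hunks above fix the same off-by-one. Python's random.randint(a, b) is inclusive at both ends, unlike range() or random.randrange(), so sampling with len(dataset) as the upper bound can return len(dataset) itself and raise IndexError on lookup. A minimal illustration (the list here is a stand-in, not the real dataset class):

    import random

    dataset = ["im0.jpg", "im1.jpg", "im2.jpg"]
    # old: random.randint(0, len(dataset)) may return 3 -> dataset[3] raises IndexError
    idx = random.randint(0, len(dataset) - 1)  # new: always a valid index
    print(dataset[idx])
    # random.randrange(len(dataset)) is the equivalent half-open form.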
@@ -366,7 +366,7 @@ class RandomPerspective:
         segments = instances.segments
         keypoints = instances.keypoints
         # update bboxes if there are segments.
-        if segments is not None:
+        if len(segments):
             bboxes, segments = self.apply_segments(segments, M)
 
         if keypoints is not None:
@@ -379,7 +379,7 @@ class RandomPerspective:
         # make the bboxes have the same scale with new_bboxes
         i = self.box_candidates(box1=instances.bboxes.T,
                                 box2=new_instances.bboxes.T,
-                                area_thr=0.01 if segments is not None else 0.10)
+                                area_thr=0.01 if len(segments) else 0.10)
         labels["instances"] = new_instances[i]
         # clip
         labels["cls"] = cls[i]
@@ -518,7 +518,7 @@ class CopyPaste:
         bboxes = labels["instances"].bboxes
         segments = labels["instances"].segments  # n, 1000, 2
         keypoints = labels["instances"].keypoints
-        if self.p and segments is not None:
+        if self.p and len(segments):
             n = len(segments)
             h, w, _ = im.shape  # height, width, channels
             im_new = np.zeros(im.shape, np.uint8)
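Note: the `segments is not None` guards are dropped in favor of `len(segments)` because, after this commit, segments is always an ndarray, possibly zero-length, rather than None (see the Instances changes further down). len() is also the only safe truthiness test here, as a sketch shows (shapes are illustrative):

    import numpy as np

    segments = np.zeros((0, 1000, 2), dtype=np.float32)  # "no segments" sentinel
    if len(segments):  # length of the first axis: 0 is falsy, so the branch is skipped
        pass
    # `if segments:` would raise ValueError ("truth value of an array ... is ambiguous")
    # for any array with more than one element, so len() is used instead.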
@@ -593,10 +593,18 @@ class Albumentations:
 # TODO: technically this is not an augmentation, maybe we should put this to another files
 class Format:
 
-    def __init__(self, bbox_format="xywh", normalize=True, mask=False, mask_ratio=4, mask_overlap=True, batch_idx=True):
+    def __init__(self,
+                 bbox_format="xywh",
+                 normalize=True,
+                 return_mask=False,
+                 return_keypoint=False,
+                 mask_ratio=4,
+                 mask_overlap=True,
+                 batch_idx=True):
         self.bbox_format = bbox_format
         self.normalize = normalize
-        self.mask = mask  # set False when training detection only
+        self.return_mask = return_mask  # set False when training detection only
+        self.return_keypoint = return_keypoint
         self.mask_ratio = mask_ratio
         self.mask_overlap = mask_overlap
         self.batch_idx = batch_idx  # keep the batch indexes
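The renamed flags make the task explicit: detection-only pipelines leave both return flags off, segmentation opts into masks, pose opts into keypoints. A hedged usage sketch (these argument combinations are illustrative, not taken from the repo):

    fmt_det = Format(bbox_format="xywh", normalize=True)                         # detection
    fmt_seg = Format(bbox_format="xywh", normalize=True, return_mask=True)       # segmentation
    fmt_pose = Format(bbox_format="xywh", normalize=True, return_keypoint=True)  # pose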
@@ -610,16 +618,20 @@ class Format:
         instances.denormalize(w, h)
         nl = len(instances)
-        if instances.segments is not None and self.mask:
-            masks, instances, cls = self._format_segments(instances, cls, w, h)
-            labels["masks"] = (torch.from_numpy(masks) if nl else torch.zeros(
-                1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio))
+        if self.return_mask:
+            if nl:
+                masks, instances, cls = self._format_segments(instances, cls, w, h)
+                masks = torch.from_numpy(masks)
+            else:
+                masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
+                                    img.shape[1] // self.mask_ratio)
+            labels["masks"] = masks
         if self.normalize:
             instances.normalize(w, h)
         labels["img"] = self._format_img(img)
         labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
         labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
-        if instances.keypoints is not None:
+        if self.return_keypoint:
             labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
         # then we can use collate_fn
         if self.batch_idx:
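When an image has no labels (nl == 0), the new else-branch still emits a masks tensor of the shape downstream code expects: one overlap-encoded mask of (H // mask_ratio, W // mask_ratio) when mask_overlap is set, otherwise a zero-instance stack. A quick shape check (a sketch, assuming a 640x640 image and mask_ratio=4):

    import torch

    img_shape, mask_ratio, mask_overlap, nl = (640, 640, 3), 4, True, 0
    masks = torch.zeros(1 if mask_overlap else nl,
                        img_shape[0] // mask_ratio,
                        img_shape[1] // mask_ratio)
    print(masks.shape)  # torch.Size([1, 160, 160]) -- an empty but well-shaped placeholder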


@@ -132,7 +132,12 @@ class YOLODataset(BaseDataset):
             transforms = affine_transforms(self.img_size, hyp)
         else:
             transforms = Compose([LetterBox(new_shape=(self.img_size, self.img_size))])
-        transforms.append(Format(bbox_format="xywh", normalize=True, mask=self.use_segments, batch_idx=True))
+        transforms.append(
+            Format(bbox_format="xywh",
+                   normalize=True,
+                   return_mask=self.use_segments,
+                   return_keypoint=self.use_keypoints,
+                   batch_idx=True))
         return transforms
 
     def update_labels_info(self, label):
def update_labels_info(self, label):
@@ -140,7 +145,7 @@ class YOLODataset(BaseDataset):
         # NOTE: cls is not with bboxes now, since other tasks like classification and semantic segmentation need a independent cls label
         # we can make it also support classification and semantic segmentation by add or remove some dict keys there.
         bboxes = label.pop("bboxes")
-        segments = label.pop("segments", None)
+        segments = label.pop("segments")
         keypoints = label.pop("keypoints", None)
         bbox_format = label.pop("bbox_format")
         normalized = label.pop("normalized")
@@ -158,9 +163,9 @@ class YOLODataset(BaseDataset):
             value = values[i]
             if k == "img":
                 value = torch.stack(value, 0)
-            if k in ["mask", "keypoint", "bboxes", "cls"]:
+            if k in ["masks", "keypoints", "bboxes", "cls"]:
                 value = torch.cat(value, 0)
-            new_batch[k] = values[i]
+            new_batch[k] = value
         new_batch["batch_idx"] = list(new_batch["batch_idx"])
         for i in range(len(new_batch["batch_idx"])):
             new_batch["batch_idx"][i] += i  # add target image index for build_targets()


@@ -52,7 +52,7 @@ def verify_image_label(args):
     # Verify one image-label pair
     im_file, lb_file, prefix, keypoint = args
     # number (missing, found, empty, corrupt), message, segments, keypoints
-    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None
+    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
     try:
         # verify images
         im = Image.open(im_file)


@@ -162,7 +162,7 @@ class Bboxes:
 class Instances:
 
-    def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
+    def __init__(self, bboxes, segments=[], keypoints=None, bbox_format="xywh", normalized=True) -> None:
         """
         Args:
             bboxes (ndarray): bboxes with shape [N, 4].
@@ -173,11 +173,13 @@ class Instances:
         self.keypoints = keypoints
         self.normalized = normalized
 
-        if isinstance(segments, list) and len(segments) > 0:
+        if len(segments) > 0:
             # list[np.array(1000, 2)] * num_samples
             segments = resample_segments(segments)
             # (N, 1000, 2)
             segments = np.stack(segments, axis=0)
+        else:
+            segments = np.zeros((0, 1000, 2), dtype=np.float32)
         self.segments = segments
 
     def convert_bbox(self, format):
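Normalizing segments to a fixed-shape array, (N, 1000, 2) when present and (0, 1000, 2) when absent, is what lets the methods below drop their None guards: vectorized operations over a zero-length leading axis are simply no-ops. A sketch of that behavior:

    import numpy as np

    segments = np.zeros((0, 1000, 2), dtype=np.float32)  # empty sentinel
    segments[..., 0] *= 0.5  # broadcasts over zero rows: does nothing, raises nothing
    print(segments.shape)    # (0, 1000, 2) -- shape and dtype are preserved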
@@ -191,9 +193,8 @@ class Instances:
         self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
         if bbox_only:
             return
-        if self.segments is not None:
-            self.segments[..., 0] *= scale_w
-            self.segments[..., 1] *= scale_h
+        self.segments[..., 0] *= scale_w
+        self.segments[..., 1] *= scale_h
         if self.keypoints is not None:
             self.keypoints[..., 0] *= scale_w
             self.keypoints[..., 1] *= scale_h
@@ -202,9 +203,8 @@ class Instances:
         if not self.normalized:
             return
         self._bboxes.mul(scale=(w, h, w, h))
-        if self.segments is not None:
-            self.segments[..., 0] *= w
-            self.segments[..., 1] *= h
+        self.segments[..., 0] *= w
+        self.segments[..., 1] *= h
         if self.keypoints is not None:
             self.keypoints[..., 0] *= w
             self.keypoints[..., 1] *= h
@@ -214,9 +214,8 @@ class Instances:
         if self.normalized:
             return
         self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
-        if self.segments is not None:
-            self.segments[..., 0] /= w
-            self.segments[..., 1] /= h
+        self.segments[..., 0] /= w
+        self.segments[..., 1] /= h
         if self.keypoints is not None:
             self.keypoints[..., 0] /= w
             self.keypoints[..., 1] /= h
@@ -226,9 +225,8 @@ class Instances:
         # handle rect and mosaic situation
         assert not self.normalized, "you should add padding with absolute coordinates."
         self._bboxes.add(offset=(padw, padh, padw, padh))
-        if self.segments is not None:
-            self.segments[..., 0] += padw
-            self.segments[..., 1] += padh
+        self.segments[..., 0] += padw
+        self.segments[..., 1] += padh
         if self.keypoints is not None:
             self.keypoints[..., 0] += padw
             self.keypoints[..., 1] += padh
@@ -241,7 +239,7 @@ class Instances:
         Returns:
             Instances: Create a new :class:`Instances` by indexing.
         """
-        segments = self.segments[index] if self.segments is not None else None
+        segments = self.segments[index] if len(self.segments) else self.segments
         keypoints = self.keypoints[index] if self.keypoints is not None else None
         bboxes = self.bboxes[index]
         bbox_format = self._bboxes.format
@@ -256,16 +254,14 @@ class Instances:
     def flipud(self, h):
         # this function may not be very logical, just for clean code when using augment flipud
         self.bboxes[:, 1] = h - self.bboxes[:, 1]
-        if self.segments is not None:
-            self.segments[..., 1] = h - self.segments[..., 1]
+        self.segments[..., 1] = h - self.segments[..., 1]
         if self.keypoints is not None:
             self.keypoints[..., 1] = h - self.keypoints[..., 1]
 
     def fliplr(self, w):
         # this function may not be very logical, just for clean code when using augment fliplr
         self.bboxes[:, 0] = w - self.bboxes[:, 0]
-        if self.segments is not None:
-            self.segments[..., 0] = w - self.segments[..., 0]
+        self.segments[..., 0] = w - self.segments[..., 0]
         if self.keypoints is not None:
             self.keypoints[..., 0] = w - self.keypoints[..., 0]
@@ -273,9 +269,8 @@ class Instances:
         self.convert_bbox(format="xyxy")
         self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
         self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
-        if self.segments is not None:
-            self.segments[..., 0] = self.segments[..., 0].clip(0, w)
-            self.segments[..., 1] = self.segments[..., 1].clip(0, h)
+        self.segments[..., 0] = self.segments[..., 0].clip(0, w)
+        self.segments[..., 1] = self.segments[..., 1].clip(0, h)
         if self.keypoints is not None:
             self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
             self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
@@ -311,13 +306,12 @@ class Instances:
         if len(instances_list) == 1:
             return instances_list[0]
 
-        use_segment = instances_list[0].segments is not None
         use_keypoint = instances_list[0].keypoints is not None
         bbox_format = instances_list[0]._bboxes.format
         normalized = instances_list[0].normalized
 
         cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
-        cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis) if use_segment else None
+        cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
         cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
         return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)
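Dropping use_segment is safe because np.concatenate joins zero-length arrays without complaint: concatenating several (0, 1000, 2) sentinels along the default axis just yields another (0, 1000, 2) array. Illustrative check:

    import numpy as np

    a = np.zeros((0, 1000, 2), dtype=np.float32)
    b = np.zeros((0, 1000, 2), dtype=np.float32)
    print(np.concatenate([a, b], axis=0).shape)  # (0, 1000, 2)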