Fix dataloader (#32)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -127,7 +127,7 @@ class Mosaic(BaseMixTransform):
         self.border = border
 
     def get_indexes(self, dataset):
-        return [random.randint(0, len(dataset)) for _ in range(3)]
+        return [random.randint(0, len(dataset) - 1) for _ in range(3)]
 
     def _mix_transform(self, labels):
         mosaic_labels = []
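
The one-character change above (mirrored in MixUp below) is part of the dataloader fix: random.randint is inclusive on both bounds, so sampling up to len(dataset) could return an index one past the end of the dataset. A minimal sketch of the failure mode, using a hypothetical three-item dataset:

    import random

    dataset = ["img0.jpg", "img1.jpg", "img2.jpg"]  # hypothetical stand-in for a Dataset

    # old code: randint(0, 3) can return 3, which raises IndexError when used as an index
    # fixed code: randint(0, 2) always yields a valid index
    idx = random.randint(0, len(dataset) - 1)
    print(dataset[idx])
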
@@ -200,7 +200,7 @@ class MixUp(BaseMixTransform):
         super().__init__(pre_transform=pre_transform, p=p)
 
     def get_indexes(self, dataset):
-        return random.randint(0, len(dataset))
+        return random.randint(0, len(dataset) - 1)
 
     def _mix_transform(self, labels):
         im = labels["img"]
@@ -366,7 +366,7 @@ class RandomPerspective:
         segments = instances.segments
         keypoints = instances.keypoints
         # update bboxes if there are segments.
-        if segments is not None:
+        if len(segments):
             bboxes, segments = self.apply_segments(segments, M)
 
         if keypoints is not None:
@@ -379,7 +379,7 @@ class RandomPerspective:
         # make the bboxes have the same scale with new_bboxes
         i = self.box_candidates(box1=instances.bboxes.T,
                                 box2=new_instances.bboxes.T,
-                                area_thr=0.01 if segments is not None else 0.10)
+                                area_thr=0.01 if len(segments) else 0.10)
         labels["instances"] = new_instances[i]
         # clip
         labels["cls"] = cls[i]
@@ -518,7 +518,7 @@ class CopyPaste:
         bboxes = labels["instances"].bboxes
         segments = labels["instances"].segments  # n, 1000, 2
         keypoints = labels["instances"].keypoints
-        if self.p and segments is not None:
+        if self.p and len(segments):
            n = len(segments)
            h, w, _ = im.shape  # height, width, channels
            im_new = np.zeros(im.shape, np.uint8)
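
The three hunks above all replace `segments is not None` checks with `len(segments)` truth tests. That lines up with the change to verify_image_label at the end of the patch, where segments now default to an empty list instead of None: emptiness, not None-ness, is what the augmentations care about. A minimal sketch with an assumed (n, 1000, 2) segment array, as hinted by the CopyPaste comment:

    import numpy as np

    segments = np.zeros((0, 1000, 2), dtype=np.float32)  # image with no segment labels

    if len(segments):      # 0 -> falsy, so the segment-specific branch is skipped
        print("apply segment-aware logic")
    else:
        print("box-only path")
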
@@ -593,10 +593,18 @@ class Albumentations:
 # TODO: technically this is not an augmentation, maybe we should put this to another files
 class Format:
 
-    def __init__(self, bbox_format="xywh", normalize=True, mask=False, mask_ratio=4, mask_overlap=True, batch_idx=True):
+    def __init__(self,
+                 bbox_format="xywh",
+                 normalize=True,
+                 return_mask=False,
+                 return_keypoint=False,
+                 mask_ratio=4,
+                 mask_overlap=True,
+                 batch_idx=True):
         self.bbox_format = bbox_format
         self.normalize = normalize
-        self.mask = mask  # set False when training detection only
+        self.return_mask = return_mask  # set False when training detection only
+        self.return_keypoint = return_keypoint
         self.mask_ratio = mask_ratio
         self.mask_overlap = mask_overlap
         self.batch_idx = batch_idx  # keep the batch indexes
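
With the single `mask` flag split into `return_mask` and `return_keypoint`, the constructor now states which extra tensors the formatter should emit. A rough usage sketch under the new names; the configurations below are illustrative only and assume the Format class defined in this patch is in scope:

    # detection only: no extra tensors beyond img/cls/bboxes/batch_idx
    detect_fmt = Format(bbox_format="xywh", normalize=True, batch_idx=True)

    # segmentation: also emit labels["masks"], downsampled by mask_ratio
    segment_fmt = Format(bbox_format="xywh",
                         normalize=True,
                         return_mask=True,
                         mask_ratio=4,
                         mask_overlap=True,
                         batch_idx=True)

    # pose/keypoints: also emit labels["keypoints"]
    pose_fmt = Format(bbox_format="xywh", normalize=True, return_keypoint=True, batch_idx=True)
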
@@ -610,16 +618,20 @@ class Format:
         instances.denormalize(w, h)
         nl = len(instances)
 
-        if instances.segments is not None and self.mask:
-            masks, instances, cls = self._format_segments(instances, cls, w, h)
-            labels["masks"] = (torch.from_numpy(masks) if nl else torch.zeros(
-                1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio))
+        if self.return_mask:
+            if nl:
+                masks, instances, cls = self._format_segments(instances, cls, w, h)
+                masks = torch.from_numpy(masks)
+            else:
+                masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
+                                    img.shape[1] // self.mask_ratio)
+            labels["masks"] = masks
         if self.normalize:
             instances.normalize(w, h)
         labels["img"] = self._format_img(img)
         labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
         labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
-        if instances.keypoints is not None:
+        if self.return_keypoint:
             labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
         # then we can use collate_fn
         if self.batch_idx:
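
The rewritten branch decides what to return from the user-supplied flags rather than from whether the label happens to carry segments or keypoints, and it still emits a correctly shaped all-zero mask tensor when an image has no labels (nl == 0), so batches containing empty images remain collatable. A sketch of that empty-image fallback shape, assuming a 640x640 input and the default mask_ratio of 4:

    import torch

    nl, mask_overlap, mask_ratio = 0, True, 4
    img_h = img_w = 640  # assumed image size for illustration

    masks = torch.zeros(1 if mask_overlap else nl, img_h // mask_ratio, img_w // mask_ratio)
    print(masks.shape)  # torch.Size([1, 160, 160]) -- concatenates cleanly in collate_fn
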
@@ -132,7 +132,12 @@ class YOLODataset(BaseDataset):
             transforms = affine_transforms(self.img_size, hyp)
         else:
             transforms = Compose([LetterBox(new_shape=(self.img_size, self.img_size))])
-        transforms.append(Format(bbox_format="xywh", normalize=True, mask=self.use_segments, batch_idx=True))
+        transforms.append(
+            Format(bbox_format="xywh",
+                   normalize=True,
+                   return_mask=self.use_segments,
+                   return_keypoint=self.use_keypoints,
+                   batch_idx=True))
         return transforms
 
     def update_labels_info(self, label):
@@ -140,7 +145,7 @@ class YOLODataset(BaseDataset):
         # NOTE: cls is not with bboxes now, since other tasks like classification and semantic segmentation need a independent cls label
         # we can make it also support classification and semantic segmentation by add or remove some dict keys there.
         bboxes = label.pop("bboxes")
-        segments = label.pop("segments", None)
+        segments = label.pop("segments")
         keypoints = label.pop("keypoints", None)
         bbox_format = label.pop("bbox_format")
         normalized = label.pop("normalized")
@@ -158,9 +163,9 @@ class YOLODataset(BaseDataset):
             value = values[i]
             if k == "img":
                 value = torch.stack(value, 0)
-            if k in ["mask", "keypoint", "bboxes", "cls"]:
+            if k in ["masks", "keypoints", "bboxes", "cls"]:
                 value = torch.cat(value, 0)
-            new_batch[k] = values[i]
+            new_batch[k] = value
         new_batch["batch_idx"] = list(new_batch["batch_idx"])
         for i in range(len(new_batch["batch_idx"])):
             new_batch["batch_idx"][i] += i  # add target image index for build_targets()
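
Two fixes land in this hunk: the key names now match what Format actually produces ("masks" and "keypoints" rather than "mask" and "keypoint"), and the stacked/concatenated value is stored instead of the untouched per-image list. The trailing loop then offsets each image's batch_idx so every target remembers which image in the batch it belongs to. A minimal sketch of that offset with hypothetical target counts:

    import torch

    # two images in a batch: the first has 2 targets, the second has 3
    batch_idx = [torch.zeros(2), torch.zeros(3)]
    for i in range(len(batch_idx)):
        batch_idx[i] += i  # add target image index for build_targets()
    print(torch.cat(batch_idx, 0))  # tensor([0., 0., 1., 1., 1.])
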
@@ -52,7 +52,7 @@ def verify_image_label(args):
     # Verify one image-label pair
     im_file, lb_file, prefix, keypoint = args
     # number (missing, found, empty, corrupt), message, segments, keypoints
-    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None
+    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
     try:
         # verify images
         im = Image.open(im_file)
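
Defaulting segments to an empty list here is what makes the `len(segments)` checks earlier in the patch safe: an image with no segment labels now yields a container that every consumer can measure or iterate without a None guard. A tiny sketch of the difference:

    segments = []            # verified image, but no segment labels found
    print(len(segments))     # 0 -- works everywhere the patch now calls len(...)
    # with the old default of None, len(segments) would raise TypeError and
    # callers needed explicit "segments is not None" checks
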