From 07b57c03c890cff11f932844707ba9c6858fc26f Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 25 May 2023 00:42:13 +0200 Subject: [PATCH] Buffered Mosaic for reduced HDD reads (#2791) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/yolo/data/augment.py | 25 ++++++++++++++++--------- ultralytics/yolo/data/base.py | 28 ++++++++++++++++++++-------- ultralytics/yolo/utils/instance.py | 14 ++++++++++++-- 3 files changed, 48 insertions(+), 19 deletions(-) diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py index c3aec94..b7aea5b 100644 --- a/ultralytics/yolo/data/augment.py +++ b/ultralytics/yolo/data/augment.py @@ -93,7 +93,7 @@ class BaseMixTransform: indexes = [indexes] # Get images information will be used for Mosaic or MixUp - mix_labels = [self.dataset.get_label_info(i) for i in indexes] + mix_labels = [self.dataset.get_image_and_label(i) for i in indexes] if self.pre_transform is not None: for i, data in enumerate(mix_labels): @@ -135,12 +135,15 @@ class Mosaic(BaseMixTransform): super().__init__(dataset=dataset, p=p) self.dataset = dataset self.imgsz = imgsz - self.border = [-imgsz // 2, -imgsz // 2] if n == 4 else [-imgsz, -imgsz] + self.border = (-imgsz // 2, -imgsz // 2) # width, height self.n = n - def get_indexes(self): + def get_indexes(self, buffer=True): """Return a list of random indexes from the dataset.""" - return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)] + if buffer: # select images from buffer + return random.choices(list(self.dataset.buffer), k=self.n - 1) + else: # select any images + return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)] def _mix_transform(self, labels): """Apply mixup transformation to the input image and labels.""" @@ -224,10 +227,12 @@ class Mosaic(BaseMixTransform): img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax] hp, wp = h, w # height, width previous for next iteration - labels_patch = self._update_labels(labels_patch, padw, padh) + # Labels assuming imgsz*2 mosaic size + labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1]) mosaic_labels.append(labels_patch) final_labels = self._cat_labels(mosaic_labels) - final_labels['img'] = img9 + + final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]] return final_labels @staticmethod @@ -245,18 +250,20 @@ class Mosaic(BaseMixTransform): return {} cls = [] instances = [] + imgsz = self.imgsz * 2 # mosaic imgsz for labels in mosaic_labels: cls.append(labels['cls']) instances.append(labels['instances']) final_labels = { 'im_file': mosaic_labels[0]['im_file'], 'ori_shape': mosaic_labels[0]['ori_shape'], - 'resized_shape': (self.imgsz * 2, self.imgsz * 2), + 'resized_shape': (imgsz, imgsz), 'cls': np.concatenate(cls, 0), 'instances': Instances.concatenate(instances, axis=0), 'mosaic_border': self.border} # final_labels - clip_size = self.imgsz * (2 if self.n == 4 else 3) - final_labels['instances'].clip(clip_size, clip_size) + final_labels['instances'].clip(imgsz, imgsz) + good = final_labels['instances'].remove_zero_area_boxes() + final_labels['cls'] = final_labels['cls'][good] return final_labels diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py index 624bc04..d2ea0cc 100644 --- a/ultralytics/yolo/data/base.py +++ b/ultralytics/yolo/data/base.py @@ -80,7 +80,7 @@ class BaseDataset(Dataset): # Cache stuff if cache == 'ram' and not self.check_cache_ram(): cache = False - self.ims = [None] * self.ni + self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files] if cache: self.cache_images(cache) @@ -88,6 +88,10 @@ class BaseDataset(Dataset): # Transforms self.transforms = self.build_transforms(hyp=hyp) + # Buffer thread for mosaic images + self.buffer = [] # buffer size = batch size + self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0 + def get_img_files(self, img_path): """Read image files.""" try: @@ -147,13 +151,22 @@ class BaseDataset(Dataset): interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)), interpolation=interp) - return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized - return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized + + # Add to buffer if training with augmentations + if self.augment: + self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized + self.buffer.append(i) + if len(self.buffer) >= self.max_buffer_length: + j = self.buffer.pop(0) + self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None + + return im, (h0, w0), im.shape[:2] + + return self.ims[i], self.im_hw0[i], self.im_hw[i] def cache_images(self, cache): """Cache images to memory or disk.""" b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes - self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image with ThreadPool(NUM_THREADS) as pool: results = pool.imap(fcn, range(self.ni)) @@ -218,9 +231,9 @@ class BaseDataset(Dataset): def __getitem__(self, index): """Returns transformed label information for given index.""" - return self.transforms(self.get_label_info(index)) + return self.transforms(self.get_image_and_label(index)) - def get_label_info(self, index): + def get_image_and_label(self, index): """Get and return label information from the dataset.""" label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948 label.pop('shape', None) # shape is for rect, remove it @@ -229,8 +242,7 @@ class BaseDataset(Dataset): label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation if self.rect: label['rect_shape'] = self.batch_shapes[self.batch[index]] - label = self.update_labels_info(label) - return label + return self.update_labels_info(label) def __len__(self): """Returns the length of the labels list for the dataset.""" diff --git a/ultralytics/yolo/utils/instance.py b/ultralytics/yolo/utils/instance.py index caef68d..619ac30 100644 --- a/ultralytics/yolo/utils/instance.py +++ b/ultralytics/yolo/utils/instance.py @@ -326,10 +326,20 @@ class Instances: self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w) self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h) + def remove_zero_area_boxes(self): + """Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height. This removes them.""" + good = self._bboxes.areas() > 0 + if not all(good): + self._bboxes = Bboxes(self._bboxes.bboxes[good], format=self._bboxes.format) + if len(self.segments): + self.segments = self.segments[good] + if self.keypoints is not None: + self.keypoints = self.keypoints[good] + return good + def update(self, bboxes, segments=None, keypoints=None): """Updates instance variables.""" - new_bboxes = Bboxes(bboxes, format=self._bboxes.format) - self._bboxes = new_bboxes + self._bboxes = Bboxes(bboxes, format=self._bboxes.format) if segments is not None: self.segments = segments if keypoints is not None: