Buffered Mosaic for reduced HDD reads (#2791)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent dada5b73c4
commit 07b57c03c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -93,7 +93,7 @@ class BaseMixTransform:
indexes = [indexes]
# Get images information will be used for Mosaic or MixUp
mix_labels = [self.dataset.get_label_info(i) for i in indexes]
mix_labels = [self.dataset.get_image_and_label(i) for i in indexes]
if self.pre_transform is not None:
for i, data in enumerate(mix_labels):
@ -135,12 +135,15 @@ class Mosaic(BaseMixTransform):
super().__init__(dataset=dataset, p=p)
self.dataset = dataset
self.imgsz = imgsz
self.border = [-imgsz // 2, -imgsz // 2] if n == 4 else [-imgsz, -imgsz]
self.border = (-imgsz // 2, -imgsz // 2) # width, height
self.n = n
def get_indexes(self):
def get_indexes(self, buffer=True):
"""Return a list of random indexes from the dataset."""
return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]
if buffer: # select images from buffer
return random.choices(list(self.dataset.buffer), k=self.n - 1)
else: # select any images
return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]
def _mix_transform(self, labels):
"""Apply mixup transformation to the input image and labels."""
@ -224,10 +227,12 @@ class Mosaic(BaseMixTransform):
img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax]
hp, wp = h, w # height, width previous for next iteration
labels_patch = self._update_labels(labels_patch, padw, padh)
# Labels assuming imgsz*2 mosaic size
labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1])
mosaic_labels.append(labels_patch)
final_labels = self._cat_labels(mosaic_labels)
final_labels['img'] = img9
final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
return final_labels
@staticmethod
@ -245,18 +250,20 @@ class Mosaic(BaseMixTransform):
return {}
cls = []
instances = []
imgsz = self.imgsz * 2 # mosaic imgsz
for labels in mosaic_labels:
cls.append(labels['cls'])
instances.append(labels['instances'])
final_labels = {
'im_file': mosaic_labels[0]['im_file'],
'ori_shape': mosaic_labels[0]['ori_shape'],
'resized_shape': (self.imgsz * 2, self.imgsz * 2),
'resized_shape': (imgsz, imgsz),
'cls': np.concatenate(cls, 0),
'instances': Instances.concatenate(instances, axis=0),
'mosaic_border': self.border} # final_labels
clip_size = self.imgsz * (2 if self.n == 4 else 3)
final_labels['instances'].clip(clip_size, clip_size)
final_labels['instances'].clip(imgsz, imgsz)
good = final_labels['instances'].remove_zero_area_boxes()
final_labels['cls'] = final_labels['cls'][good]
return final_labels

@ -80,7 +80,7 @@ class BaseDataset(Dataset):
# Cache stuff
if cache == 'ram' and not self.check_cache_ram():
cache = False
self.ims = [None] * self.ni
self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
if cache:
self.cache_images(cache)
@ -88,6 +88,10 @@ class BaseDataset(Dataset):
# Transforms
self.transforms = self.build_transforms(hyp=hyp)
# Buffer thread for mosaic images
self.buffer = [] # buffer size = batch size
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
def get_img_files(self, img_path):
"""Read image files."""
try:
@ -147,13 +151,22 @@ class BaseDataset(Dataset):
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)),
interpolation=interp)
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
# Add to buffer if training with augmentations
if self.augment:
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
self.buffer.append(i)
if len(self.buffer) >= self.max_buffer_length:
j = self.buffer.pop(0)
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
return im, (h0, w0), im.shape[:2]
return self.ims[i], self.im_hw0[i], self.im_hw[i]
def cache_images(self, cache):
"""Cache images to memory or disk."""
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(fcn, range(self.ni))
@ -218,9 +231,9 @@ class BaseDataset(Dataset):
def __getitem__(self, index):
"""Returns transformed label information for given index."""
return self.transforms(self.get_label_info(index))
return self.transforms(self.get_image_and_label(index))
def get_label_info(self, index):
def get_image_and_label(self, index):
"""Get and return label information from the dataset."""
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
label.pop('shape', None) # shape is for rect, remove it
@ -229,8 +242,7 @@ class BaseDataset(Dataset):
label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation
if self.rect:
label['rect_shape'] = self.batch_shapes[self.batch[index]]
label = self.update_labels_info(label)
return label
return self.update_labels_info(label)
def __len__(self):
"""Returns the length of the labels list for the dataset."""

@ -326,10 +326,20 @@ class Instances:
self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
def remove_zero_area_boxes(self):
"""Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height. This removes them."""
good = self._bboxes.areas() > 0
if not all(good):
self._bboxes = Bboxes(self._bboxes.bboxes[good], format=self._bboxes.format)
if len(self.segments):
self.segments = self.segments[good]
if self.keypoints is not None:
self.keypoints = self.keypoints[good]
return good
def update(self, bboxes, segments=None, keypoints=None):
"""Updates instance variables."""
new_bboxes = Bboxes(bboxes, format=self._bboxes.format)
self._bboxes = new_bboxes
self._bboxes = Bboxes(bboxes, format=self._bboxes.format)
if segments is not None:
self.segments = segments
if keypoints is not None:

Loading…
Cancel
Save