|
|
|
@ -80,7 +80,7 @@ class BaseDataset(Dataset):
|
|
|
|
|
# Cache stuff
|
|
|
|
|
if cache == 'ram' and not self.check_cache_ram():
|
|
|
|
|
cache = False
|
|
|
|
|
self.ims = [None] * self.ni
|
|
|
|
|
self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
|
|
|
|
|
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
|
|
|
|
|
if cache:
|
|
|
|
|
self.cache_images(cache)
|
|
|
|
@ -88,6 +88,10 @@ class BaseDataset(Dataset):
|
|
|
|
|
# Transforms
|
|
|
|
|
self.transforms = self.build_transforms(hyp=hyp)
|
|
|
|
|
|
|
|
|
|
# Buffer thread for mosaic images
|
|
|
|
|
self.buffer = [] # buffer size = batch size
|
|
|
|
|
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
|
|
|
|
|
|
|
|
|
|
def get_img_files(self, img_path):
|
|
|
|
|
"""Read image files."""
|
|
|
|
|
try:
|
|
|
|
@ -147,13 +151,22 @@ class BaseDataset(Dataset):
|
|
|
|
|
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
|
|
|
|
|
im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)),
|
|
|
|
|
interpolation=interp)
|
|
|
|
|
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
|
|
|
|
|
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
|
|
|
|
|
|
|
|
|
|
# Add to buffer if training with augmentations
|
|
|
|
|
if self.augment:
|
|
|
|
|
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
|
|
|
|
|
self.buffer.append(i)
|
|
|
|
|
if len(self.buffer) >= self.max_buffer_length:
|
|
|
|
|
j = self.buffer.pop(0)
|
|
|
|
|
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
|
|
|
|
|
|
|
|
|
|
return im, (h0, w0), im.shape[:2]
|
|
|
|
|
|
|
|
|
|
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
|
|
|
|
|
|
|
|
|
def cache_images(self, cache):
|
|
|
|
|
"""Cache images to memory or disk."""
|
|
|
|
|
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
|
|
|
|
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
|
|
|
|
|
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
|
|
|
|
|
with ThreadPool(NUM_THREADS) as pool:
|
|
|
|
|
results = pool.imap(fcn, range(self.ni))
|
|
|
|
@ -218,9 +231,9 @@ class BaseDataset(Dataset):
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, index):
    """Return the transformed sample for the given index.

    The sample is fetched with ``self.get_image_and_label(index)`` and the
    result is passed through ``self.transforms``.

    Args:
        index: Position of the sample, forwarded unchanged to
            ``self.get_image_and_label``.

    Returns:
        Whatever ``self.transforms`` returns for the fetched sample.
    """
    # Single exit point: the previous body had two consecutive returns, so
    # the second one (the get_image_and_label call) was never reached.
    return self.transforms(self.get_image_and_label(index))
|
|
|
|
|
|
|
|
|
|
def get_label_info(self, index):
|
|
|
|
|
def get_image_and_label(self, index):
|
|
|
|
|
"""Get and return label information from the dataset."""
|
|
|
|
|
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
|
|
|
|
|
label.pop('shape', None) # shape is for rect, remove it
|
|
|
|
@ -229,8 +242,7 @@ class BaseDataset(Dataset):
|
|
|
|
|
label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation
|
|
|
|
|
if self.rect:
|
|
|
|
|
label['rect_shape'] = self.batch_shapes[self.batch[index]]
|
|
|
|
|
label = self.update_labels_info(label)
|
|
|
|
|
return label
|
|
|
|
|
return self.update_labels_info(label)
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
"""Returns the length of the labels list for the dataset."""
|
|
|
|
|