# Ultralytics YOLO 🚀, AGPL-3.0 license

from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision
from tqdm import tqdm

from ..utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label


class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Args:
        img_path (str): Path to the folder containing images.
        imgsz (int, optional): Image size. Defaults to 640.
        cache (bool, optional): Cache images to RAM or disk during training. Defaults to False.
        augment (bool, optional): If True, data augmentation is applied. Defaults to True.
        hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None.
        prefix (str, optional): Prefix to print in log messages. Defaults to ''.
        rect (bool, optional): If True, rectangular training is used. Defaults to False.
        batch_size (int, optional): Size of batches. Defaults to None.
        stride (int, optional): Stride. Defaults to 32.
        pad (float, optional): Padding. Defaults to 0.0.
        single_cls (bool, optional): If True, single class training is used. Defaults to False.
        use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
        use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.
        data (dict, optional): A dataset YAML dictionary. Defaults to None. Label caching reads
            ``data['names']`` and (optionally) ``data['kpt_shape']``, so in practice it must be provided.
        classes (list, optional): List of included classes. Defaults to None.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    cache_version = '1.0.2'  # dataset labels *.cache version, >= 1.0.0 for YOLOv8
    # OpenCV interpolation flags; not referenced in this module (presumably used by BaseDataset/augment — confirm).
    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]

    def __init__(self,
                 img_path,
                 imgsz=640,
                 cache=False,
                 augment=True,
                 hyp=None,
                 prefix='',
                 rect=False,
                 batch_size=None,
                 stride=32,
                 pad=0.0,
                 single_cls=False,
                 use_segments=False,
                 use_keypoints=False,
                 data=None,
                 classes=None):
        # These attributes are set before super().__init__() because cache_labels()/build_transforms() read them;
        # presumably BaseDataset.__init__ invokes those methods — confirm against base.py.
        self.use_segments = use_segments
        self.use_keypoints = use_keypoints
        self.data = data
        assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
        super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls,
                         classes)

    def cache_labels(self, path=Path('./labels.cache')):
        """
        Cache dataset labels, check images and read shapes.

        Every (image, label file) pair is verified in a thread pool via ``verify_image_label``. The result is
        written to ``path`` when its directory is writeable.

        Args:
            path (Path): path where to save the cache file (default: Path('./labels.cache')).

        Returns:
            (dict): cache dict with keys:
                - 'labels': list of per-image dicts (im_file, shape, cls, bboxes, segments, keypoints,
                  normalized=True, bbox_format='xywh'),
                - 'hash': hash of label + image file paths,
                - 'results': (found, missing, empty, corrupt, total),
                - 'msgs': list of warning messages,
                - 'version': ``cache_version``.

        Raises:
            ValueError: if keypoints are requested but ``kpt_shape`` in the data dict is missing or invalid.
        """
        x = {'labels': []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
        total = len(self.im_files)
        nkpt, ndim = self.data.get('kpt_shape', (0, 0))
        if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
            raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                             "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image_label,
                                iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
                                             repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
                                             repeat(ndim)))
            pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:  # falsy im_file means the pair was rejected (e.g. corrupt) — no label entry
                    x['labels'].append(
                        dict(
                            im_file=im_file,
                            shape=shape,
                            cls=lb[:, 0:1],  # n, 1
                            bboxes=lb[:, 1:],  # n, 4
                            segments=segments,
                            keypoints=keypoint,
                            normalized=True,
                            bbox_format='xywh'))
                if msg:
                    msgs.append(msg)
                pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
            pbar.close()
        if msgs:
            LOGGER.info('\n'.join(msgs))
        if nf == 0:
            LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
        x['hash'] = get_hash(self.label_files + self.im_files)
        x['results'] = nf, nm, ne, nc, len(self.im_files)
        x['msgs'] = msgs  # warnings
        x['version'] = self.cache_version  # cache version
        if is_dir_writeable(path.parent):
            if path.exists():
                path.unlink()  # remove *.cache file if exists
            np.save(str(path), x)  # save cache for next time (np.save appends '.npy')
            path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
            LOGGER.info(f'{self.prefix}New cache created: {path}')
        else:
            LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
        return x

    def get_labels(self):
        """
        Load (or build) the list of label dicts for ``self.im_files``.

        The ``*.cache`` file next to the label directory is used only if it loads and its 'version' and 'hash'
        match the current ``cache_version`` and label/image files. Otherwise the labels are rescanned with
        ``cache_labels()``.

        Side effects: sets ``self.label_files`` and replaces ``self.im_files`` with the images that produced a
        label entry.

        Returns:
            (list[dict]): label dicts as produced by ``cache_labels()``.

        Raises:
            FileNotFoundError: if no label files were found.
            ValueError: if every label is empty.
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
        try:
            import gc
            gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
            try:
                # SECURITY NOTE: allow_pickle=True unpickles the cache file; only load caches from trusted datasets.
                cache = np.load(str(cache_path), allow_pickle=True).item()  # load dict
            finally:
                gc.enable()  # always re-enable GC, even when the load fails (e.g. no cache file yet)
            exists = True
            assert cache['version'] == self.cache_version  # matches current version
            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError, KeyError):
            # KeyError: cache dict from an incompatible writer lacking 'version'/'hash' — rebuild instead of crashing
            cache, exists = self.cache_labels(cache_path), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in (-1, 0):
            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
            if cache['msgs']:
                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
        if nf == 0:  # number of labels found
            raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}')

        # Read cache
        for k in ('hash', 'version', 'msgs'):
            cache.pop(k)  # remove items
        labels = cache['labels']
        self.im_files = [lb['im_file'] for lb in labels]  # update im_files

        # Check if the dataset is all boxes or all segments.
        # Explicit sums (rather than unpacking zip(*...)) also work when `labels` is empty.
        len_cls = sum(len(lb['cls']) for lb in labels)
        len_boxes = sum(len(lb['bboxes']) for lb in labels)
        len_segments = sum(len(lb['segments']) for lb in labels)
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
                f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
                'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
            for lb in labels:
                lb['segments'] = []
        if len_cls == 0:
            raise ValueError(f'All labels empty in {cache_path}, can not start training without labels. {HELP_URL}')
        return labels

    # TODO: use hyp config to set all these augmentations
    def build_transforms(self, hyp=None):
        """
        Build the transform pipeline, always ending with a ``Format`` step.

        With augmentation the pipeline comes from ``v8_transforms``; mosaic and mixup are forced to 0.0 for
        rectangular training. Without augmentation it is a non-upscaling ``LetterBox``.

        Args:
            hyp: hyperparameter namespace. Despite the ``None`` default it is required: ``mask_ratio`` and
                ``overlap_mask`` are always read, and ``mosaic``/``mixup`` are written when augmenting.

        Returns:
            The composed transform pipeline.
        """
        if self.augment:
            hyp.mosaic = hyp.mosaic if not self.rect else 0.0
            hyp.mixup = hyp.mixup if not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(bbox_format='xywh',
                   normalize=True,
                   return_mask=self.use_segments,
                   return_keypoint=self.use_keypoints,
                   batch_idx=True,
                   mask_ratio=hyp.mask_ratio,
                   mask_overlap=hyp.overlap_mask))
        return transforms

    def close_mosaic(self, hyp):
        """Disable mosaic, copy-paste and mixup in ``hyp`` (mutated in place) and rebuild ``self.transforms``."""
        hyp.mosaic = 0.0  # set mosaic ratio=0.0
        hyp.copy_paste = 0.0  # keep the same behavior as previous v8 close-mosaic
        hyp.mixup = 0.0  # keep the same behavior as previous v8 close-mosaic
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label):
        """
        Customize the label format here.

        Pops 'bboxes', 'segments', 'keypoints' (optional), 'bbox_format' and 'normalized' from ``label``. Stores
        them as one ``Instances`` object under ``label['instances']``, then returns the same (mutated) dict.
        """
        # NOTE: cls is not stored with bboxes now; classification and semantic segmentation need an independent cls
        # label. Support for them could be added by adding or removing some dict keys here.
        bboxes = label.pop('bboxes')
        segments = label.pop('segments')
        keypoints = label.pop('keypoints', None)
        bbox_format = label.pop('bbox_format')
        normalized = label.pop('normalized')
        label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """
        Collate a list of per-sample dicts into one batch dict.

        'img' tensors are stacked along a new dim 0. 'masks', 'keypoints', 'bboxes' and 'cls' are concatenated
        along dim 0. Every other key is kept as a tuple of per-sample values. Each sample's 'batch_idx' is offset
        by that sample's position in the batch before concatenation, so every target row records its image.

        Assumes all sample dicts share the same keys in the same order (values are zipped positionally).
        """
        new_batch = {}
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        for k, value in zip(keys, values):
            if k == 'img':
                value = torch.stack(value, 0)
            elif k in ('masks', 'keypoints', 'bboxes', 'cls'):
                value = torch.cat(value, 0)
            new_batch[k] = value
        # Add target image index for build_targets(). Out-of-place add so the caller's tensors are not mutated.
        new_batch['batch_idx'] = torch.cat([idx + i for i, idx in enumerate(new_batch['batch_idx'])], 0)
        return new_batch


# Classification dataloaders -------------------------------------------------------------------------------------------
class ClassificationDataset(torchvision.datasets.ImageFolder):
    """
    YOLOv5 Classification Dataset.

    Args:
        root (str): Dataset path (ImageFolder layout).
        augment (bool): If True, build Albumentations transforms via ``classify_albumentations``. These are used
            instead of the torchvision transforms whenever they are truthy.
        imgsz (int): Image size passed to the transform builders.
        cache (bool | str, optional): True or 'ram' caches decoded images in memory; 'disk' caches them as
            ``.npy`` files next to the images. Defaults to False.

    Attributes:
        torch_transforms: transforms from ``classify_transforms(imgsz)``, used when no Albumentations transforms.
        album_transforms: Albumentations transforms, or None when ``augment`` is False.
        samples (list): [file, class index, npy path, cached image or None] per sample.
    """

    def __init__(self, root, augment, imgsz, cache=False):
        super().__init__(root=root)
        self.torch_transforms = classify_transforms(imgsz)
        self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
        self.cache_ram = cache is True or cache == 'ram'
        self.cache_disk = cache == 'disk'
        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im

    def __getitem__(self, i):
        """Return ``{'img': transformed image, 'cls': class index}`` for sample ``i``."""
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram:
            # Nested check: with RAM caching on and the image already cached, use it rather than falling through
            # to the uncached cv2.imread() branch below.
            if im is None:
                im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # create the .npy cache on first access
                np.save(fn.as_posix(), cv2.imread(f))
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        if self.album_transforms:
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
        else:
            sample = self.torch_transforms(im)
        return {'img': sample, 'cls': j}

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.samples)


# TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
    """Placeholder for a semantic segmentation dataset; not implemented (does not call BaseDataset.__init__)."""

    def __init__(self):
        pass