from itertools import repeat
from multiprocessing.pool import Pool
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision
from tqdm import tqdm

from ..utils.general import LOGGER, NUM_THREADS
from .augment import *
from .base import BaseDataset
from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label


class YOLODataset(BaseDataset):
    """YOLO Dataset.

    Labels are verified with ``verify_image_label`` and cached to a ``*.cache`` file placed next to the
    label directory, so later runs can skip the scan.

    Args:
        img_path (str): image path.
        prefix (str): prefix prepended to log messages.
        rect (bool): rectangular batches; disables mosaic augmentation in ``build_transforms``.
        use_segments (bool): forwarded to ``Format(mask=...)`` in ``build_transforms``.
        use_keypoints (bool): forwarded to ``verify_image_label``; mutually exclusive with ``use_segments``.
        img_size, label_path, cache, augment, hyp, batch_size, stride, pad, single_cls:
            passed through to ``BaseDataset``.
    """

    cache_version = 0.6  # dataset labels *.cache version
    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]

    def __init__(
        self,
        img_path,
        img_size=640,
        label_path=None,
        cache=False,
        augment=True,
        hyp=None,
        prefix="",
        rect=False,
        batch_size=None,
        stride=32,
        pad=0.0,
        single_cls=False,
        use_segments=False,
        use_keypoints=False,
    ):
        # Set before super().__init__(): presumably BaseDataset.__init__ calls get_labels()/build_transforms(),
        # which read these flags (BaseDataset is not visible here - confirm).
        self.use_segments = use_segments
        self.use_keypoints = use_keypoints
        assert not (self.use_segments and self.use_keypoints), "We can't use both of segmentation and pose."
        super().__init__(img_path, img_size, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
                         single_cls)

    def cache_labels(self, path=Path("./labels.cache")):
        """Verify every image/label pair in parallel, build the label cache dict and try to save it.

        Args:
            path (Path): destination of the cache file.

        Returns:
            dict: ``labels`` (list of per-image label dicts), ``hash``, ``results``
                (found, missing, empty, corrupt, total), ``msgs`` and ``version``.
        """
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning '{path.parent / path.stem}' images and labels..."
        with Pool(NUM_THREADS) as pool:
            pbar = tqdm(
                pool.imap(verify_image_label,
                          zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
                desc=desc,
                total=len(self.im_files),
                bar_format=BAR_FORMAT,
            )
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x["labels"].append(
                        dict(
                            im_file=im_file,
                            shape=shape,
                            cls=lb[:, 0:1],  # n, 1
                            bboxes=lb[:, 1:],  # n, 4
                            segments=segments,
                            keypoints=keypoint,
                            normalized=True,
                            bbox_format="xywh",
                        ))
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
            pbar.close()
        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        x["version"] = self.cache_version  # cache version
        try:
            np.save(path, x)  # save cache for next time
            # np.save appended ".npy"; strip it. replace() (unlike rename()) also overwrites a stale
            # cache on Windows instead of raising FileExistsError.
            path.with_suffix(".cache.npy").replace(path)
            LOGGER.info(f"{self.prefix}New cache created: {path}")
        except Exception as e:
            LOGGER.warning(
                f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}")  # not writeable
        return x

    def get_labels(self):
        """Return the list of per-image label dicts, reusing the ``*.cache`` file when it is valid.

        The cache is rebuilt with ``cache_labels`` when it is missing or unreadable, was written with a
        different ``cache_version``, or its hash of label + image paths no longer matches.

        Raises:
            AssertionError: if no labels were found, or every label is empty.
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            # NOTE(security): allow_pickle=True unpickles the file; a crafted .cache shipped alongside a
            # downloaded dataset could execute code. Only trust caches this code wrote.
            cache = np.load(cache_path, allow_pickle=True).item()  # load dict
            # Explicit comparisons rather than `assert`: asserts are stripped under `python -O`, which would
            # let an outdated or stale cache through unchecked.
            exists = (cache["version"] == self.cache_version  # matches current version
                      and cache["hash"] == get_hash(self.label_files + self.im_files))  # identical hash
        except Exception:
            exists = False
        if not exists:
            cache = self.cache_labels(cache_path)  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=BAR_FORMAT)  # display cache results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings
        assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"

        # Read cache
        for k in ("hash", "version", "msgs"):
            cache.pop(k)  # remove bookkeeping entries, keep only "labels"
        labels = cache["labels"]
        nl = sum(len(label["cls"]) for label in labels)  # number of labels (cls is n x 1 per image)
        assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}"
        return labels

    # TODO: use hyp config to set all these augmentations
    def build_transforms(self, hyp=None):
        """Build the sample transform pipeline.

        When augmenting, uses mosaic augmentation, or affine augmentation for rectangular batches;
        otherwise a plain letterbox to ``img_size``. A ``Format`` step (normalized xywh boxes, masks when
        ``use_segments``, batch_idx) is always appended.
        """
        if not self.augment:
            transforms = Compose([LetterBox(new_shape=(self.img_size, self.img_size))])
        elif self.rect:
            transforms = affine_transforms(self.img_size, hyp)  # mosaic is incompatible with rect batches
        else:
            transforms = mosaic_transforms(self.img_size, hyp)
        transforms.append(Format(bbox_format="xywh", normalize=True, mask=self.use_segments, batch_idx=True))
        return transforms

    def update_labels_info(self, label):
        """Bundle a label dict's box/segment/keypoint fields into ``label["instances"]``.

        Pops ``bboxes``, ``segments``, ``keypoints``, ``bbox_format`` and ``normalized`` from ``label``,
        stores an ``Instances`` built from them, and returns the same (mutated) dict.
        Customize your label format here.
        """
        # NOTE: cls is not stored with bboxes, since other tasks like classification and semantic segmentation
        # need an independent cls label. Those tasks could be supported by adding or removing dict keys here.
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", None)
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """Merge a list of per-sample dicts (all with the same keys, in the same order) into one batch dict.

        ``img`` tensors are stacked along a new leading dim; ``mask``, ``keypoint``, ``bboxes`` and ``cls``
        are concatenated along dim 0; any other key is kept as a tuple of per-sample values. The i-th
        sample's ``batch_idx`` is offset by ``i`` and all of them are concatenated, so every target row
        records which image in the batch it belongs to.
        """
        # TODO: returning a dict can make things easier and cleaner when using the dataset in training,
        # but it may slow things down a little.
        keys = batch[0].keys()
        columns = zip(*(sample.values() for sample in batch))  # one tuple per key, across samples
        new_batch = {}
        for k, value in zip(keys, columns):
            if k == "img":
                value = torch.stack(value, 0)
            if k in {"mask", "keypoint", "bboxes", "cls"}:
                value = torch.cat(value, 0)
            new_batch[k] = value  # store the collated tensor, not the raw per-sample tuple
        # add target image index for build_targets(); builds new tensors instead of mutating the samples'
        new_batch["batch_idx"] = torch.cat([idx + i for i, idx in enumerate(new_batch["batch_idx"])], 0)
        return new_batch


# Classification dataloaders -------------------------------------------------------------------------------------------
class ClassificationDataset(torchvision.datasets.ImageFolder):
    """
    YOLOv5 Classification Dataset.

    Arguments
        root: Dataset path
        augment: if truthy, build Albumentations transforms (``classify_albumentations``) and use them
            when they are truthy; otherwise torchvision transforms (``classify_transforms``) are used
        imgsz: image size passed to the transform builders
        cache: True or "ram" keeps decoded images in memory; "disk" stores them as ``.npy`` files next to the images
    """

    def __init__(self, root, augment, imgsz, cache=False):
        super().__init__(root=root)
        self.torch_transforms = classify_transforms(imgsz)
        self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
        self.cache_ram = cache is True or cache == "ram"
        self.cache_disk = cache == "disk"
        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im

    def __getitem__(self, i):
        """Load sample ``i`` (BGR via cv2), apply the transforms and return ``(sample, class_index)``."""
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram:
            # Only decode on first access; afterwards reuse the in-memory copy instead of re-reading from disk.
            if im is None:
                im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # create the npy cache on first access
                np.save(fn.as_posix(), cv2.imread(f))
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        if self.album_transforms:
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
        else:
            sample = self.torch_transforms(im)
        return sample, j


# TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
    """Placeholder for a semantic segmentation dataset; not implemented yet."""

    def __init__(self):
        pass