You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
214 lines
8.9 KiB
214 lines
8.9 KiB
2 years ago
|
from itertools import repeat
|
||
|
from multiprocessing.pool import Pool
|
||
|
from pathlib import Path
|
||
|
|
||
|
import cv2
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
import torchvision
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
from ..utils.general import LOGGER, NUM_THREADS
|
||
|
from .augment import *
|
||
|
from .base import BaseDataset
|
||
|
from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
|
||
|
|
||
|
|
||
|
class YOLODataset(BaseDataset):
    """YOLO Dataset.

    Reads images and their label files, verifies them with ``verify_image_label`` and
    caches the parsed labels in ``<labels dir>.cache`` so later runs can skip the scan.

    Args:
        img_path (str): image path.
        prefix (str): prefix prepended to progress/log messages.
        use_segments (bool): forwarded to ``Format(mask=...)`` when building transforms.
        use_keypoints (bool): forwarded to ``verify_image_label``; cannot be combined with
            ``use_segments``.

    The remaining constructor arguments are forwarded positionally to ``BaseDataset``.
    """

    cache_version = 0.6  # dataset labels *.cache version
    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]

    def __init__(
        self,
        img_path,
        img_size=640,
        label_path=None,
        cache=False,
        augment=True,
        hyp=None,
        prefix="",
        rect=False,
        batch_size=None,
        stride=32,
        pad=0.0,
        single_cls=False,
        use_segments=False,
        use_keypoints=False,
    ):
        self.use_segments = use_segments
        self.use_keypoints = use_keypoints
        assert not (self.use_segments and self.use_keypoints), "We can't use both of segmentation and pose."
        super().__init__(img_path, img_size, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
                         single_cls)

    def cache_labels(self, path=Path("./labels.cache")):
        """Verify every image/label pair in a process pool and save the results to ``path``.

        Args:
            path (Path): destination of the cache file.

        Returns:
            (dict): ``labels`` (one dict per image for which ``verify_image_label`` returned an
            ``im_file``), ``hash``, ``results`` (found, missing, empty, corrupt, total),
            ``msgs`` (warnings) and ``version``.
        """
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning '{path.parent / path.stem}' images and labels..."
        with Pool(NUM_THREADS) as pool:
            pbar = tqdm(
                pool.imap(verify_image_label,
                          zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
                desc=desc,
                total=len(self.im_files),
                bar_format=BAR_FORMAT,
            )
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x["labels"].append(
                        dict(
                            im_file=im_file,
                            shape=shape,
                            cls=lb[:, 0:1],  # n, 1
                            bboxes=lb[:, 1:],  # n, 4
                            segments=segments,
                            keypoints=keypoint,
                            normalized=True,
                            bbox_format="xywh",
                        ))
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"

        pbar.close()
        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        x["version"] = self.cache_version  # cache version
        try:
            np.save(str(path), x)  # np.save appends '.npy' to names that do not already end with it
            # Path.replace (os.replace) overwrites an existing stale cache; Path.rename fails on Windows if it exists
            Path(f"{path}.npy").replace(path)  # remove .npy suffix
            LOGGER.info(f"{self.prefix}New cache created: {path}")
        except Exception as e:
            LOGGER.warning(
                f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}")  # not writeable
        return x

    def get_labels(self):
        """Return the per-image label dicts, from the ``.cache`` file when it is still valid.

        The cache is rebuilt with ``cache_labels`` when it is missing, unreadable, has a different
        ``cache_version``, or its hash no longer matches the current label and image files.

        Raises:
            AssertionError: if no label files were found or every label is empty.
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache = np.load(cache_path, allow_pickle=True).item()  # load dict
            # explicit comparisons instead of asserts so validation also runs under `python -O`
            exists = (cache["version"] == self.cache_version  # matches current version
                      and cache["hash"] == get_hash(self.label_files + self.im_files))  # identical hash
        except Exception:
            exists = False
        if not exists:
            cache = self.cache_labels(cache_path)  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=BAR_FORMAT)  # display cache results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings
        assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"

        # Read cache
        for k in ("hash", "version", "msgs"):
            cache.pop(k)  # remove items
        labels = cache["labels"]
        nl = sum(len(label["cls"]) for label in labels)  # number of labels; 0 (not ValueError) when empty
        assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}"
        return labels

    # TODO: use hyp config to set all these augmentations
    def build_transforms(self, hyp=None):
        """Build the transform pipeline for this dataset.

        With augmentation: mosaic transforms, or affine transforms when ``rect`` batching is on.
        Without augmentation: a square ``LetterBox`` resize. ``Format`` is always appended last.
        """
        if self.augment:
            mosaic = not self.rect  # mosaic mixes images, which is incompatible with rect batching
            if mosaic:
                transforms = mosaic_transforms(self.img_size, hyp)
            else:
                transforms = affine_transforms(self.img_size, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.img_size, self.img_size))])
        transforms.append(Format(bbox_format="xywh", normalize=True, mask=self.use_segments, batch_idx=True))
        return transforms

    def update_labels_info(self, label):
        """Replace the raw box/segment/keypoint entries of ``label`` with one ``Instances`` object.

        Customize the label format here.
        """
        # NOTE: cls is not with bboxes now, since other tasks like classification and semantic segmentation need a
        # independent cls label; we can make it also support classification and semantic segmentation by adding or
        # removing some dict keys here.
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", None)
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """Collate a list of per-sample dicts into one batch dict.

        ``img`` is stacked along a new batch dimension, and ``mask``, ``keypoint``, ``bboxes`` and
        ``cls`` are concatenated along dim 0. ``batch_idx`` is offset by each sample's position in
        the batch and then concatenated, so every target can be mapped back to its image. Any
        other key maps to a tuple of the per-sample values.
        """
        # TODO: returning a dict can make thing easier and cleaner when using dataset in training
        # but I don't know if this will slow down a little bit.
        new_batch = {}
        for k in batch[0]:
            value = tuple(b[k] for b in batch)  # look up by key instead of relying on dict order
            if k == "img":
                value = torch.stack(value, 0)
            elif k in ("mask", "keypoint", "bboxes", "cls"):
                value = torch.cat(value, 0)
            new_batch[k] = value  # store the collated result, not the raw per-sample tuple
        # add target image index for build_targets()
        new_batch["batch_idx"] = torch.cat([idx + i for i, idx in enumerate(new_batch["batch_idx"])], 0)
        return new_batch
|
||
|
|
||
|
|
||
|
# Classification dataloaders -------------------------------------------------------------------------------------------
|
||
|
class ClassificationDataset(torchvision.datasets.ImageFolder):
    """
    YOLOv5 Classification Dataset.

    Arguments
        root: Dataset path (an ImageFolder-style directory)
        augment: when truthy, samples go through the transforms built by ``classify_albumentations``
            (called as ``t(image=rgb)["image"]``) instead of ``classify_transforms``
        imgsz: image size passed to the transform builders
        cache: True or "ram" to keep decoded images in memory after the first read, "disk" to store each
            decoded image as a ``.npy`` file next to the original, anything else to read from disk every time
    """

    def __init__(self, root, augment, imgsz, cache=False):
        super().__init__(root=root)
        self.torch_transforms = classify_transforms(imgsz)
        self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
        self.cache_ram = cache is True or cache == "ram"
        self.cache_disk = cache == "disk"
        # each sample becomes a mutable list so the RAM cache can store the decoded image in slot 3
        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im

    def __getitem__(self, i):
        """Return ``(transformed image, class index)`` for sample ``i``."""
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram:
            # keep this check separate from the cache_ram test: an image already cached in RAM
            # must not fall through to the plain disk read below
            if im is None:
                im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # create the .npy on first access
                np.save(fn.as_posix(), cv2.imread(f))
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        if self.album_transforms:
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
        else:
            sample = self.torch_transforms(im)
        return sample, j
|
||
|
|
||
|
|
||
|
# TODO: support semantic segmentation
|
||
|
class SemanticDataset(BaseDataset):
    """Placeholder for a semantic-segmentation dataset; not implemented yet.

    NOTE(review): ``__init__`` does not call ``BaseDataset.__init__``, so instances carry none
    of the base-class state and are not usable as a dataset.
    """

    def __init__(self):
        """Intentionally does nothing until semantic segmentation is supported."""
|