New dataset fraction=1.0 argument (#2860)

This commit is contained in:
Glenn Jocher
2023-05-28 02:13:46 +02:00
committed by GitHub
parent 61fa5efe6d
commit 0bdd4ad379
7 changed files with 16 additions and 5 deletions

View File

@ -36,6 +36,7 @@ class BaseDataset(Dataset):
pad (float, optional): Padding. Defaults to 0.0.
single_cls (bool, optional): If True, single class training is used. Defaults to False.
classes (list): List of included classes. Default is None.
fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data).
Attributes:
im_files (list): List of image file paths.
@ -58,13 +59,15 @@ class BaseDataset(Dataset):
stride=32,
pad=0.5,
single_cls=False,
classes=None):
classes=None,
fraction=1.0):
super().__init__()
self.img_path = img_path
self.imgsz = imgsz
self.augment = augment
self.single_cls = single_cls
self.prefix = prefix
self.fraction = fraction
self.im_files = self.get_img_files(self.img_path)
self.labels = self.get_labels()
self.update_labels(include_class=classes) # single_cls and include_class
@ -114,6 +117,8 @@ class BaseDataset(Dataset):
assert im_files, f'{self.prefix}No images found'
except Exception as e:
raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
if self.fraction < 1:
im_files = im_files[:round(len(im_files) * self.fraction)]
return im_files
def update_labels(self, include_class: Optional[list]):

View File

@ -69,7 +69,7 @@ def seed_worker(worker_id): # noqa
random.seed(worker_seed)
def build_yolo_dataset(cfg, img_path, batch, data_info, mode='train', rect=False, stride=32):
def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32):
"""Build YOLO Dataset"""
return YOLODataset(
img_path=img_path,
@ -86,7 +86,8 @@ def build_yolo_dataset(cfg, img_path, batch, data_info, mode='train', rect=False
use_segments=cfg.task == 'segment',
use_keypoints=cfg.task == 'pose',
classes=cfg.classes,
data=data_info)
data=data,
fraction=cfg.fraction if mode == 'train' else 1.0)
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):

View File

@ -226,6 +226,8 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
cache (Union[bool, str], optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
"""
super().__init__(root=root)
if augment and args.fraction < 1.0: # reduce training fraction
self.samples = self.samples[:round(len(self.samples) * args.fraction)]
self.cache_ram = cache is True or cache == 'ram'
self.cache_disk = cache == 'disk'
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
@ -269,4 +271,4 @@ class SemanticDataset(BaseDataset):
def __init__(self):
"""Initialize a SemanticDataset object."""
pass
super().__init__()