New dataset `fraction=1.0` argument (#2860)
This commit is contained in:
@@ -36,6 +36,7 @@ class BaseDataset(Dataset):
|
||||
pad (float, optional): Padding. Defaults to 0.0.
|
||||
single_cls (bool, optional): If True, single class training is used. Defaults to False.
|
||||
classes (list): List of included classes. Default is None.
|
||||
fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data).
|
||||
|
||||
Attributes:
|
||||
im_files (list): List of image file paths.
|
||||
@@ -58,13 +59,15 @@ class BaseDataset(Dataset):
|
||||
stride=32,
|
||||
pad=0.5,
|
||||
single_cls=False,
|
||||
classes=None):
|
||||
classes=None,
|
||||
fraction=1.0):
|
||||
super().__init__()
|
||||
self.img_path = img_path
|
||||
self.imgsz = imgsz
|
||||
self.augment = augment
|
||||
self.single_cls = single_cls
|
||||
self.prefix = prefix
|
||||
self.fraction = fraction
|
||||
self.im_files = self.get_img_files(self.img_path)
|
||||
self.labels = self.get_labels()
|
||||
self.update_labels(include_class=classes) # single_cls and include_class
|
||||
@@ -114,6 +117,8 @@ class BaseDataset(Dataset):
|
||||
assert im_files, f'{self.prefix}No images found'
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
|
||||
if self.fraction < 1:
|
||||
im_files = im_files[:round(len(im_files) * self.fraction)]
|
||||
return im_files
|
||||
|
||||
def update_labels(self, include_class: Optional[list]):
|
||||
|
@@ -69,7 +69,7 @@ def seed_worker(worker_id): # noqa
|
||||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def build_yolo_dataset(cfg, img_path, batch, data_info, mode='train', rect=False, stride=32):
|
||||
def build_yolo_dataset(cfg, img_path, batch, data, mode='train', rect=False, stride=32):
|
||||
"""Build YOLO Dataset"""
|
||||
return YOLODataset(
|
||||
img_path=img_path,
|
||||
@@ -86,7 +86,8 @@ def build_yolo_dataset(cfg, img_path, batch, data_info, mode='train', rect=False
|
||||
use_segments=cfg.task == 'segment',
|
||||
use_keypoints=cfg.task == 'pose',
|
||||
classes=cfg.classes,
|
||||
data=data_info)
|
||||
data=data,
|
||||
fraction=cfg.fraction if mode == 'train' else 1.0)
|
||||
|
||||
|
||||
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
|
||||
|
@@ -226,6 +226,8 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||
cache (Union[bool, str], optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
|
||||
"""
|
||||
super().__init__(root=root)
|
||||
if augment and args.fraction < 1.0: # reduce training fraction
|
||||
self.samples = self.samples[:round(len(self.samples) * args.fraction)]
|
||||
self.cache_ram = cache is True or cache == 'ram'
|
||||
self.cache_disk = cache == 'disk'
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
|
||||
@@ -269,4 +271,4 @@ class SemanticDataset(BaseDataset):
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize a SemanticDataset object."""
|
||||
pass
|
||||
super().__init__()
|
||||
|
Reference in New Issue
Block a user