Improved Classify dataset splits and introspection (#4312)

single_channel
Glenn Jocher 1 year ago committed by GitHub
parent 645c715350
commit 39395aedc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -304,7 +304,8 @@ def check_cls_dataset(dataset: str, split=''):
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n" s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
LOGGER.info(s) LOGGER.info(s)
train_set = data_dir / 'train' train_set = data_dir / 'train'
val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if (
data_dir / 'validation').exists() else None # data/test or data/val
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
if split == 'val' and not val_set: if split == 'val' and not val_set:
LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.") LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
@@ -314,6 +315,17 @@ def check_cls_dataset(dataset: str, split=''):
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
names = dict(enumerate(sorted(names))) names = dict(enumerate(sorted(names)))
# Print to console
for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items():
if v is None:
LOGGER.info(colorstr(k) + f': {v}')
else:
files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS]
nf = len(files) # number of files
nd = len({file.parent for file in files}) # number of directories
LOGGER.info(colorstr(k) + f': {v}... found {nf} images in {nd} classes ✅ ') # keep trailing space
return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names} return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}

Loading…
Cancel
Save