Improved Classify dataset splits and introspection (#4312)
This commit is contained in:
@ -304,7 +304,8 @@ def check_cls_dataset(dataset: str, split=''):
|
||||
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
|
||||
LOGGER.info(s)
|
||||
train_set = data_dir / 'train'
|
||||
val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val
|
||||
val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if (
|
||||
data_dir / 'validation').exists() else None # data/test or data/val
|
||||
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
|
||||
if split == 'val' and not val_set:
|
||||
LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
|
||||
@ -314,6 +315,17 @@ def check_cls_dataset(dataset: str, split=''):
|
||||
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
|
||||
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
|
||||
names = dict(enumerate(sorted(names)))
|
||||
|
||||
# Print to console
|
||||
for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items():
|
||||
if v is None:
|
||||
LOGGER.info(colorstr(k) + f': {v}')
|
||||
else:
|
||||
files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS]
|
||||
nf = len(files) # number of files
|
||||
nd = len({file.parent for file in files}) # number of directories
|
||||
LOGGER.info(colorstr(k) + f': {v}... found {nf} images in {nd} classes ✅ ') # keep trailing space
|
||||
|
||||
return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user