|
|
|
@ -304,7 +304,8 @@ def check_cls_dataset(dataset: str, split=''):
|
|
|
|
|
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
|
|
|
|
|
LOGGER.info(s)
|
|
|
|
|
train_set = data_dir / 'train'
|
|
|
|
|
val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val
|
|
|
|
|
val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if (
|
|
|
|
|
data_dir / 'validation').exists() else None # data/test or data/val
|
|
|
|
|
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
|
|
|
|
|
if split == 'val' and not val_set:
|
|
|
|
|
LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
|
|
|
|
@ -314,6 +315,17 @@ def check_cls_dataset(dataset: str, split=''):
|
|
|
|
|
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
|
|
|
|
|
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
|
|
|
|
|
names = dict(enumerate(sorted(names)))
|
|
|
|
|
|
|
|
|
|
# Print to console
|
|
|
|
|
for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items():
|
|
|
|
|
if v is None:
|
|
|
|
|
LOGGER.info(colorstr(k) + f': {v}')
|
|
|
|
|
else:
|
|
|
|
|
files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS]
|
|
|
|
|
nf = len(files) # number of files
|
|
|
|
|
nd = len({file.parent for file in files}) # number of directories
|
|
|
|
|
LOGGER.info(colorstr(k) + f': {v}... found {nf} images in {nd} classes ✅ ') # keep trailing space
|
|
|
|
|
|
|
|
|
|
return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|