From 39395aedc8888b67136ba617ef599d86d7ab63e8 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 11 Aug 2023 20:39:50 +0200 Subject: [PATCH] Improved Classify dataset splits and introspection (#4312) --- ultralytics/data/utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/ultralytics/data/utils.py b/ultralytics/data/utils.py index a46f0b7..b728546 100644 --- a/ultralytics/data/utils.py +++ b/ultralytics/data/utils.py @@ -304,7 +304,8 @@ def check_cls_dataset(dataset: str, split=''): s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n" LOGGER.info(s) train_set = data_dir / 'train' - val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val + val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if ( + data_dir / 'validation').exists() else None # data/test or data/val test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test if split == 'val' and not val_set: LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.") @@ -314,6 +315,17 @@ def check_cls_dataset(dataset: str, split=''): nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list names = dict(enumerate(sorted(names))) + + # Print to console + for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items(): + if v is None: + LOGGER.info(colorstr(k) + f': {v}') + else: + files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS] + nf = len(files) # number of files + nd = len({file.parent for file in files}) # number of directories + LOGGER.info(colorstr(k) + f': {v}... found {nf} images in {nd} classes ✅ ') # keep trailing space + return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}