diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py index b4832be..a4ce973 100644 --- a/ultralytics/yolo/data/utils.py +++ b/ultralytics/yolo/data/utils.py @@ -268,28 +268,33 @@ def check_det_dataset(dataset, autodownload=True): def check_cls_dataset(dataset: str, split=''): """ - Check a classification dataset such as Imagenet. + Checks a classification dataset such as Imagenet. - This function takes a `dataset` name as input and returns a dictionary containing information about the dataset. - If the dataset is not found, it attempts to download the dataset from the internet and save it locally. + This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information. + If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally. Args: - dataset (str): Name of the dataset. - split (str, optional): Dataset split, either 'val', 'test', or ''. Defaults to ''. + dataset (str): The name of the dataset. + split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''. Returns: - data (dict): A dictionary containing the following keys and values: - 'train': Path object for the directory containing the training set of the dataset - 'val': Path object for the directory containing the validation set of the dataset - 'test': Path object for the directory containing the test set of the dataset - 'nc': Number of classes in the dataset - 'names': List of class names in the dataset + dict: A dictionary containing the following keys: + - 'train' (Path): The directory path containing the training set of the dataset. + - 'val' (Path): The directory path containing the validation set of the dataset. + - 'test' (Path): The directory path containing the test set of the dataset. + - 'nc' (int): The number of classes in the dataset. + - 'names' (dict): A dictionary of class names in the dataset. + + Raises: + FileNotFoundError: If the specified dataset is not found and cannot be downloaded. """ - data_dir = (DATASETS_DIR / dataset).resolve() + + dataset = Path(dataset) + data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve() if not data_dir.is_dir(): LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...') t = time.time() - if dataset == 'imagenet': + if str(dataset) == 'imagenet': subprocess.run(f"bash {ROOT / 'yolo/data/scripts/get_imagenet.sh'}", shell=True, check=True) else: url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' @@ -312,12 +317,12 @@ def check_cls_dataset(dataset: str, split=''): class HUBDatasetStats(): """ - Class for generating HUB dataset JSON and `-hub` dataset directory + A class for generating HUB dataset JSON and `-hub` dataset directory. - Arguments - path: Path to data.yaml or data.zip (with data.yaml inside data.zip) - task: Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. - autodownload: Attempt to download dataset if not found locally + Args: + path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco128.yaml'. + task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'. + autodownload (bool): Attempt to download dataset if not found locally. Default is False. Usage from ultralytics.yolo.data.utils import HUBDatasetStats