Fix Classify train from arbitrary CWD (#3692)

single_channel
Glenn Jocher 1 year ago committed by GitHub
parent 15e9eac27b
commit 395cc47c53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -268,28 +268,33 @@ def check_det_dataset(dataset, autodownload=True):
def check_cls_dataset(dataset: str, split=''):
    """
    Checks a classification dataset such as Imagenet.

    This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
    If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.

    Args:
        dataset (str): The name of the dataset.
        split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.

    Returns:
        dict: A dictionary containing the following keys:
            - 'train' (Path): The directory path containing the training set of the dataset.
            - 'val' (Path): The directory path containing the validation set of the dataset.
            - 'test' (Path): The directory path containing the test set of the dataset.
            - 'nc' (int): The number of classes in the dataset.
            - 'names' (dict): A dictionary of class names in the dataset.

    Raises:
        FileNotFoundError: If the specified dataset is not found and cannot be downloaded.
    """
    dataset = Path(dataset)
    data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
    if not data_dir.is_dir():
        LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
        t = time.time()
        if str(dataset) == 'imagenet':
            subprocess.run(f"bash {ROOT / 'yolo/data/scripts/get_imagenet.sh'}", shell=True, check=True)
        else:
            url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
@ -312,12 +317,12 @@ def check_cls_dataset(dataset: str, split=''):
class HUBDatasetStats():
    """
    A class for generating HUB dataset JSON and `-hub` dataset directory.

    Args:
        path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco128.yaml'.
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
        autodownload (bool): Attempt to download dataset if not found locally. Default is False.

    Usage:
        from ultralytics.yolo.data.utils import HUBDatasetStats

Loading…
Cancel
Save