Add Classification model YAML support (#154)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -255,12 +255,28 @@ def check_dataset_yaml(data, autodownload=True):
|
||||
|
||||
|
||||
def check_dataset(dataset: str):
|
||||
data = Path.cwd() / "datasets" / dataset
|
||||
data_dir = data if data.is_dir() else (Path.cwd() / data)
|
||||
"""
|
||||
Check a classification dataset such as Imagenet.
|
||||
|
||||
Copy code
|
||||
This function takes a `dataset` name as input and returns a dictionary containing information about the dataset.
|
||||
If the dataset is not found, it attempts to download the dataset from the internet and save it to the local file system.
|
||||
|
||||
Args:
|
||||
dataset (str): Name of the dataset.
|
||||
|
||||
Returns:
|
||||
data (dict): A dictionary containing the following keys and values:
|
||||
'train': Path object for the directory containing the training set of the dataset
|
||||
'val': Path object for the directory containing the validation set of the dataset
|
||||
'nc': Number of classes in the dataset
|
||||
'names': List of class names in the dataset
|
||||
"""
|
||||
data_dir = (Path.cwd() / "datasets" / dataset).resolve()
|
||||
if not data_dir.is_dir():
|
||||
LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
|
||||
t = time.time()
|
||||
if str(data) == 'imagenet':
|
||||
if dataset == 'imagenet':
|
||||
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
|
||||
else:
|
||||
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
|
||||
@ -271,5 +287,4 @@ def check_dataset(dataset: str):
|
||||
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
|
||||
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
|
||||
names = [name for name in os.listdir(data_dir / 'train') if os.path.isdir(data_dir / 'train' / name)]
|
||||
data = {"train": train_set, "val": test_set, "nc": nc, "names": names}
|
||||
return data
|
||||
return {"train": train_set, "val": test_set, "nc": nc, "names": names}
|
||||
|
Reference in New Issue
Block a user