ultralytics 8.0.107
(#2778)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Peter van Lunteren <contact@pvanlunteren.com>
This commit is contained in:
@ -766,9 +766,17 @@ def v8_transforms(dataset, imgsz, hyp):
|
||||
pre_transform=LetterBox(new_shape=(imgsz, imgsz)),
|
||||
)])
|
||||
flip_idx = dataset.data.get('flip_idx', None) # for keypoints augmentation
|
||||
if dataset.use_keypoints and flip_idx is None and hyp.fliplr > 0.0:
|
||||
hyp.fliplr = 0.0
|
||||
LOGGER.warning("WARNING ⚠️ No `flip_idx` provided while training keypoints, setting augmentation 'fliplr=0.0'")
|
||||
if dataset.use_keypoints:
|
||||
kpt_shape = dataset.data.get('kpt_shape', None)
|
||||
if flip_idx is None and hyp.fliplr > 0.0:
|
||||
hyp.fliplr = 0.0
|
||||
LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
|
||||
elif flip_idx:
|
||||
if len(flip_idx) != kpt_shape[0]:
|
||||
raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}')
|
||||
elif flip_idx[0] != 0:
|
||||
raise ValueError(f'data.yaml flip_idx={flip_idx} must be zero-index (start from 0)')
|
||||
|
||||
return Compose([
|
||||
pre_transform,
|
||||
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
|
||||
|
@ -266,7 +266,7 @@ def check_det_dataset(dataset, autodownload=True):
|
||||
return data # dictionary
|
||||
|
||||
|
||||
def check_cls_dataset(dataset: str):
|
||||
def check_cls_dataset(dataset: str, split=''):
|
||||
"""
|
||||
Check a classification dataset such as Imagenet.
|
||||
|
||||
@ -275,6 +275,7 @@ def check_cls_dataset(dataset: str):
|
||||
|
||||
Args:
|
||||
dataset (str): Name of the dataset.
|
||||
split (str, optional): Dataset split, either 'val', 'test', or ''. Defaults to ''.
|
||||
|
||||
Returns:
|
||||
data (dict): A dictionary containing the following keys and values:
|
||||
@ -298,10 +299,15 @@ def check_cls_dataset(dataset: str):
|
||||
train_set = data_dir / 'train'
|
||||
val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val
|
||||
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
|
||||
if split == 'val' and not val_set:
|
||||
LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
|
||||
elif split == 'test' and not test_set:
|
||||
LOGGER.info("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
|
||||
|
||||
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
|
||||
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
|
||||
names = dict(enumerate(sorted(names)))
|
||||
return {'train': train_set, 'val': val_set, 'test': test_set, 'nc': nc, 'names': names}
|
||||
return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}
|
||||
|
||||
|
||||
class HUBDatasetStats():
|
||||
|
Reference in New Issue
Block a user