ultralytics 8.0.105
classification hyp fix and new onplot
callbacks (#2684)
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ivan Shcheklein <shcheklein@gmail.com>
This commit is contained in:
@ -789,13 +789,20 @@ def classify_transforms(size=224, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)): #
|
||||
return T.Compose([CenterCrop(size), ToTensor()])
|
||||
|
||||
|
||||
def hsv2colorjitter(h, s, v):
|
||||
"""Map HSV (hue, saturation, value) jitter into ColorJitter values (brightness, contrast, saturation, hue)"""
|
||||
return v, v, s, h
|
||||
|
||||
|
||||
def classify_albumentations(
|
||||
augment=True,
|
||||
size=224,
|
||||
scale=(0.08, 1.0),
|
||||
hflip=0.5,
|
||||
vflip=0.0,
|
||||
jitter=0.4,
|
||||
hsv_h=0.015, # image HSV-Hue augmentation (fraction)
|
||||
hsv_s=0.7, # image HSV-Saturation augmentation (fraction)
|
||||
hsv_v=0.4, # image HSV-Value augmentation (fraction)
|
||||
mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
|
||||
std=(1.0, 1.0, 1.0), # IMAGENET_STD
|
||||
auto_aug=False,
|
||||
@ -810,16 +817,15 @@ def classify_albumentations(
|
||||
if augment: # Resize and crop
|
||||
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
|
||||
if auto_aug:
|
||||
# TODO: implement AugMix, AutoAug & RandAug in albumentation
|
||||
# TODO: implement AugMix, AutoAug & RandAug in albumentations
|
||||
LOGGER.info(f'{prefix}auto augmentations are currently not supported')
|
||||
else:
|
||||
if hflip > 0:
|
||||
T += [A.HorizontalFlip(p=hflip)]
|
||||
if vflip > 0:
|
||||
T += [A.VerticalFlip(p=vflip)]
|
||||
if jitter > 0:
|
||||
jitter = float(jitter)
|
||||
T += [A.ColorJitter(jitter, jitter, jitter, 0)] # brightness, contrast, saturation, 0 hue
|
||||
if any((hsv_h, hsv_s, hsv_v)):
|
||||
T += [A.ColorJitter(*hsv2colorjitter(hsv_h, hsv_s, hsv_v))] # brightness, contrast, saturation, hue
|
||||
else: # Use fixed crop for eval set (reproducibility)
|
||||
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
|
||||
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
|
||||
|
@ -202,21 +202,48 @@ class YOLODataset(BaseDataset):
|
||||
# Classification dataloaders -------------------------------------------------------------------------------------------
|
||||
class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||
"""
|
||||
YOLOv5 Classification Dataset.
|
||||
Arguments
|
||||
root: Dataset path
|
||||
transform: torchvision transforms, used by default
|
||||
album_transform: Albumentations transforms, used if installed
|
||||
YOLO Classification Dataset.
|
||||
|
||||
Args:
|
||||
root (str): Dataset path.
|
||||
transform (callable, optional): torchvision transforms, used by default.
|
||||
album_transform (callable, optional): Albumentations transforms, used if installed.
|
||||
|
||||
Attributes:
|
||||
cache_ram (bool): True if images should be cached in RAM, False otherwise.
|
||||
cache_disk (bool): True if images should be cached on disk, False otherwise.
|
||||
samples (list): List of samples containing file, index, npy, and im.
|
||||
torch_transforms (callable): torchvision transforms applied to the dataset.
|
||||
album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
|
||||
"""
|
||||
|
||||
def __init__(self, root, augment=False, imgsz=224, cache=False):
|
||||
"""Initialize YOLO object with root, image size, augmentations, and cache settings"""
|
||||
def __init__(self, root, args, augment=False, cache=False):
|
||||
"""
|
||||
Initialize YOLO object with root, image size, augmentations, and cache settings.
|
||||
|
||||
Args:
|
||||
root (str): Dataset path.
|
||||
args (Namespace): Argument parser containing dataset related settings.
|
||||
augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False.
|
||||
cache (Union[bool, str], optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
|
||||
"""
|
||||
super().__init__(root=root)
|
||||
self.torch_transforms = classify_transforms(imgsz)
|
||||
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
|
||||
self.cache_ram = cache is True or cache == 'ram'
|
||||
self.cache_disk = cache == 'disk'
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
|
||||
self.torch_transforms = classify_transforms(args.imgsz)
|
||||
self.album_transforms = classify_albumentations(
|
||||
augment=augment,
|
||||
size=args.imgsz,
|
||||
scale=(1.0 - args.scale, 1.0), # (0.08, 1.0)
|
||||
hflip=args.fliplr,
|
||||
vflip=args.flipud,
|
||||
hsv_h=args.hsv_h, # HSV-Hue augmentation (fraction)
|
||||
hsv_s=args.hsv_s, # HSV-Saturation augmentation (fraction)
|
||||
hsv_v=args.hsv_v, # HSV-Value augmentation (fraction)
|
||||
mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
|
||||
std=(1.0, 1.0, 1.0), # IMAGENET_STD
|
||||
auto_aug=False) if augment else None
|
||||
|
||||
def __getitem__(self, i):
|
||||
"""Returns subset of data and targets corresponding to given indices."""
|
||||
|
Reference in New Issue
Block a user