Update .pre-commit-config.yaml
(#1026)
This commit is contained in:
@ -6,11 +6,11 @@ from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
|
||||
from .dataset_wrappers import MixAndRectDataset
|
||||
|
||||
__all__ = [
|
||||
"BaseDataset",
|
||||
"ClassificationDataset",
|
||||
"MixAndRectDataset",
|
||||
"SemanticDataset",
|
||||
"YOLODataset",
|
||||
"build_classification_dataloader",
|
||||
"build_dataloader",
|
||||
"load_inference_source",]
|
||||
'BaseDataset',
|
||||
'ClassificationDataset',
|
||||
'MixAndRectDataset',
|
||||
'SemanticDataset',
|
||||
'YOLODataset',
|
||||
'build_classification_dataloader',
|
||||
'build_dataloader',
|
||||
'load_inference_source',]
|
||||
|
@ -55,11 +55,11 @@ class Compose:
|
||||
return self.transforms
|
||||
|
||||
def __repr__(self):
|
||||
format_string = f"{self.__class__.__name__}("
|
||||
format_string = f'{self.__class__.__name__}('
|
||||
for t in self.transforms:
|
||||
format_string += "\n"
|
||||
format_string += f" {t}"
|
||||
format_string += "\n)"
|
||||
format_string += '\n'
|
||||
format_string += f' {t}'
|
||||
format_string += '\n)'
|
||||
return format_string
|
||||
|
||||
|
||||
@ -86,11 +86,11 @@ class BaseMixTransform:
|
||||
if self.pre_transform is not None:
|
||||
for i, data in enumerate(mix_labels):
|
||||
mix_labels[i] = self.pre_transform(data)
|
||||
labels["mix_labels"] = mix_labels
|
||||
labels['mix_labels'] = mix_labels
|
||||
|
||||
# Mosaic or MixUp
|
||||
labels = self._mix_transform(labels)
|
||||
labels.pop("mix_labels", None)
|
||||
labels.pop('mix_labels', None)
|
||||
return labels
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
@ -109,7 +109,7 @@ class Mosaic(BaseMixTransform):
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
|
||||
assert 0 <= p <= 1.0, "The probability should be in range [0, 1]. " f"got {p}."
|
||||
assert 0 <= p <= 1.0, 'The probability should be in range [0, 1]. ' f'got {p}.'
|
||||
super().__init__(dataset=dataset, p=p)
|
||||
self.dataset = dataset
|
||||
self.imgsz = imgsz
|
||||
@ -120,15 +120,15 @@ class Mosaic(BaseMixTransform):
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
mosaic_labels = []
|
||||
assert labels.get("rect_shape", None) is None, "rect and mosaic is exclusive."
|
||||
assert len(labels.get("mix_labels", [])) > 0, "There are no other images for mosaic augment."
|
||||
assert labels.get('rect_shape', None) is None, 'rect and mosaic is exclusive.'
|
||||
assert len(labels.get('mix_labels', [])) > 0, 'There are no other images for mosaic augment.'
|
||||
s = self.imgsz
|
||||
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
|
||||
for i in range(4):
|
||||
labels_patch = (labels if i == 0 else labels["mix_labels"][i - 1]).copy()
|
||||
labels_patch = (labels if i == 0 else labels['mix_labels'][i - 1]).copy()
|
||||
# Load image
|
||||
img = labels_patch["img"]
|
||||
h, w = labels_patch.pop("resized_shape")
|
||||
img = labels_patch['img']
|
||||
h, w = labels_patch.pop('resized_shape')
|
||||
|
||||
# place img in img4
|
||||
if i == 0: # top left
|
||||
@ -152,15 +152,15 @@ class Mosaic(BaseMixTransform):
|
||||
labels_patch = self._update_labels(labels_patch, padw, padh)
|
||||
mosaic_labels.append(labels_patch)
|
||||
final_labels = self._cat_labels(mosaic_labels)
|
||||
final_labels["img"] = img4
|
||||
final_labels['img'] = img4
|
||||
return final_labels
|
||||
|
||||
def _update_labels(self, labels, padw, padh):
|
||||
"""Update labels"""
|
||||
nh, nw = labels["img"].shape[:2]
|
||||
labels["instances"].convert_bbox(format="xyxy")
|
||||
labels["instances"].denormalize(nw, nh)
|
||||
labels["instances"].add_padding(padw, padh)
|
||||
nh, nw = labels['img'].shape[:2]
|
||||
labels['instances'].convert_bbox(format='xyxy')
|
||||
labels['instances'].denormalize(nw, nh)
|
||||
labels['instances'].add_padding(padw, padh)
|
||||
return labels
|
||||
|
||||
def _cat_labels(self, mosaic_labels):
|
||||
@ -169,16 +169,16 @@ class Mosaic(BaseMixTransform):
|
||||
cls = []
|
||||
instances = []
|
||||
for labels in mosaic_labels:
|
||||
cls.append(labels["cls"])
|
||||
instances.append(labels["instances"])
|
||||
cls.append(labels['cls'])
|
||||
instances.append(labels['instances'])
|
||||
final_labels = {
|
||||
"im_file": mosaic_labels[0]["im_file"],
|
||||
"ori_shape": mosaic_labels[0]["ori_shape"],
|
||||
"resized_shape": (self.imgsz * 2, self.imgsz * 2),
|
||||
"cls": np.concatenate(cls, 0),
|
||||
"instances": Instances.concatenate(instances, axis=0),
|
||||
"mosaic_border": self.border}
|
||||
final_labels["instances"].clip(self.imgsz * 2, self.imgsz * 2)
|
||||
'im_file': mosaic_labels[0]['im_file'],
|
||||
'ori_shape': mosaic_labels[0]['ori_shape'],
|
||||
'resized_shape': (self.imgsz * 2, self.imgsz * 2),
|
||||
'cls': np.concatenate(cls, 0),
|
||||
'instances': Instances.concatenate(instances, axis=0),
|
||||
'mosaic_border': self.border}
|
||||
final_labels['instances'].clip(self.imgsz * 2, self.imgsz * 2)
|
||||
return final_labels
|
||||
|
||||
|
||||
@ -193,10 +193,10 @@ class MixUp(BaseMixTransform):
|
||||
def _mix_transform(self, labels):
|
||||
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
|
||||
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
|
||||
labels2 = labels["mix_labels"][0]
|
||||
labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
|
||||
labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
|
||||
labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
|
||||
labels2 = labels['mix_labels'][0]
|
||||
labels['img'] = (labels['img'] * r + labels2['img'] * (1 - r)).astype(np.uint8)
|
||||
labels['instances'] = Instances.concatenate([labels['instances'], labels2['instances']], axis=0)
|
||||
labels['cls'] = np.concatenate([labels['cls'], labels2['cls']], 0)
|
||||
return labels
|
||||
|
||||
|
||||
@ -338,18 +338,18 @@ class RandomPerspective:
|
||||
Args:
|
||||
labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
|
||||
"""
|
||||
if self.pre_transform and "mosaic_border" not in labels:
|
||||
if self.pre_transform and 'mosaic_border' not in labels:
|
||||
labels = self.pre_transform(labels)
|
||||
labels.pop("ratio_pad") # do not need ratio pad
|
||||
labels.pop('ratio_pad') # do not need ratio pad
|
||||
|
||||
img = labels["img"]
|
||||
cls = labels["cls"]
|
||||
instances = labels.pop("instances")
|
||||
img = labels['img']
|
||||
cls = labels['cls']
|
||||
instances = labels.pop('instances')
|
||||
# make sure the coord formats are right
|
||||
instances.convert_bbox(format="xyxy")
|
||||
instances.convert_bbox(format='xyxy')
|
||||
instances.denormalize(*img.shape[:2][::-1])
|
||||
|
||||
border = labels.pop("mosaic_border", self.border)
|
||||
border = labels.pop('mosaic_border', self.border)
|
||||
self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
|
||||
# M is affine matrix
|
||||
# scale for func:`box_candidates`
|
||||
@ -365,7 +365,7 @@ class RandomPerspective:
|
||||
|
||||
if keypoints is not None:
|
||||
keypoints = self.apply_keypoints(keypoints, M)
|
||||
new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
|
||||
new_instances = Instances(bboxes, segments, keypoints, bbox_format='xyxy', normalized=False)
|
||||
# clip
|
||||
new_instances.clip(*self.size)
|
||||
|
||||
@ -375,10 +375,10 @@ class RandomPerspective:
|
||||
i = self.box_candidates(box1=instances.bboxes.T,
|
||||
box2=new_instances.bboxes.T,
|
||||
area_thr=0.01 if len(segments) else 0.10)
|
||||
labels["instances"] = new_instances[i]
|
||||
labels["cls"] = cls[i]
|
||||
labels["img"] = img
|
||||
labels["resized_shape"] = img.shape[:2]
|
||||
labels['instances'] = new_instances[i]
|
||||
labels['cls'] = cls[i]
|
||||
labels['img'] = img
|
||||
labels['resized_shape'] = img.shape[:2]
|
||||
return labels
|
||||
|
||||
def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
|
||||
@ -397,7 +397,7 @@ class RandomHSV:
|
||||
self.vgain = vgain
|
||||
|
||||
def __call__(self, labels):
|
||||
img = labels["img"]
|
||||
img = labels['img']
|
||||
if self.hgain or self.sgain or self.vgain:
|
||||
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
|
||||
hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
|
||||
@ -415,30 +415,30 @@ class RandomHSV:
|
||||
|
||||
class RandomFlip:
|
||||
|
||||
def __init__(self, p=0.5, direction="horizontal") -> None:
|
||||
assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
|
||||
def __init__(self, p=0.5, direction='horizontal') -> None:
|
||||
assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}'
|
||||
assert 0 <= p <= 1.0
|
||||
|
||||
self.p = p
|
||||
self.direction = direction
|
||||
|
||||
def __call__(self, labels):
|
||||
img = labels["img"]
|
||||
instances = labels.pop("instances")
|
||||
instances.convert_bbox(format="xywh")
|
||||
img = labels['img']
|
||||
instances = labels.pop('instances')
|
||||
instances.convert_bbox(format='xywh')
|
||||
h, w = img.shape[:2]
|
||||
h = 1 if instances.normalized else h
|
||||
w = 1 if instances.normalized else w
|
||||
|
||||
# Flip up-down
|
||||
if self.direction == "vertical" and random.random() < self.p:
|
||||
if self.direction == 'vertical' and random.random() < self.p:
|
||||
img = np.flipud(img)
|
||||
instances.flipud(h)
|
||||
if self.direction == "horizontal" and random.random() < self.p:
|
||||
if self.direction == 'horizontal' and random.random() < self.p:
|
||||
img = np.fliplr(img)
|
||||
instances.fliplr(w)
|
||||
labels["img"] = np.ascontiguousarray(img)
|
||||
labels["instances"] = instances
|
||||
labels['img'] = np.ascontiguousarray(img)
|
||||
labels['instances'] = instances
|
||||
return labels
|
||||
|
||||
|
||||
@ -455,9 +455,9 @@ class LetterBox:
|
||||
def __call__(self, labels=None, image=None):
|
||||
if labels is None:
|
||||
labels = {}
|
||||
img = labels.get("img") if image is None else image
|
||||
img = labels.get('img') if image is None else image
|
||||
shape = img.shape[:2] # current shape [height, width]
|
||||
new_shape = labels.pop("rect_shape", self.new_shape)
|
||||
new_shape = labels.pop('rect_shape', self.new_shape)
|
||||
if isinstance(new_shape, int):
|
||||
new_shape = (new_shape, new_shape)
|
||||
|
||||
@ -479,8 +479,8 @@ class LetterBox:
|
||||
|
||||
dw /= 2 # divide padding into 2 sides
|
||||
dh /= 2
|
||||
if labels.get("ratio_pad"):
|
||||
labels["ratio_pad"] = (labels["ratio_pad"], (dw, dh)) # for evaluation
|
||||
if labels.get('ratio_pad'):
|
||||
labels['ratio_pad'] = (labels['ratio_pad'], (dw, dh)) # for evaluation
|
||||
|
||||
if shape[::-1] != new_unpad: # resize
|
||||
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
@ -491,18 +491,18 @@ class LetterBox:
|
||||
|
||||
if len(labels):
|
||||
labels = self._update_labels(labels, ratio, dw, dh)
|
||||
labels["img"] = img
|
||||
labels["resized_shape"] = new_shape
|
||||
labels['img'] = img
|
||||
labels['resized_shape'] = new_shape
|
||||
return labels
|
||||
else:
|
||||
return img
|
||||
|
||||
def _update_labels(self, labels, ratio, padw, padh):
|
||||
"""Update labels"""
|
||||
labels["instances"].convert_bbox(format="xyxy")
|
||||
labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
|
||||
labels["instances"].scale(*ratio)
|
||||
labels["instances"].add_padding(padw, padh)
|
||||
labels['instances'].convert_bbox(format='xyxy')
|
||||
labels['instances'].denormalize(*labels['img'].shape[:2][::-1])
|
||||
labels['instances'].scale(*ratio)
|
||||
labels['instances'].add_padding(padw, padh)
|
||||
return labels
|
||||
|
||||
|
||||
@ -513,11 +513,11 @@ class CopyPaste:
|
||||
|
||||
def __call__(self, labels):
|
||||
# Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
|
||||
im = labels["img"]
|
||||
cls = labels["cls"]
|
||||
im = labels['img']
|
||||
cls = labels['cls']
|
||||
h, w = im.shape[:2]
|
||||
instances = labels.pop("instances")
|
||||
instances.convert_bbox(format="xyxy")
|
||||
instances = labels.pop('instances')
|
||||
instances.convert_bbox(format='xyxy')
|
||||
instances.denormalize(w, h)
|
||||
if self.p and len(instances.segments):
|
||||
n = len(instances)
|
||||
@ -540,9 +540,9 @@ class CopyPaste:
|
||||
i = cv2.flip(im_new, 1).astype(bool)
|
||||
im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
|
||||
|
||||
labels["img"] = im
|
||||
labels["cls"] = cls
|
||||
labels["instances"] = instances
|
||||
labels['img'] = im
|
||||
labels['cls'] = cls
|
||||
labels['instances'] = instances
|
||||
return labels
|
||||
|
||||
|
||||
@ -551,11 +551,11 @@ class Albumentations:
|
||||
def __init__(self, p=1.0):
|
||||
self.p = p
|
||||
self.transform = None
|
||||
prefix = colorstr("albumentations: ")
|
||||
prefix = colorstr('albumentations: ')
|
||||
try:
|
||||
import albumentations as A
|
||||
|
||||
check_version(A.__version__, "1.0.3", hard=True) # version requirement
|
||||
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
||||
|
||||
T = [
|
||||
A.Blur(p=0.01),
|
||||
@ -565,28 +565,28 @@ class Albumentations:
|
||||
A.RandomBrightnessContrast(p=0.0),
|
||||
A.RandomGamma(p=0.0),
|
||||
A.ImageCompression(quality_lower=75, p=0.0),] # transforms
|
||||
self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
|
||||
self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
|
||||
|
||||
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
|
||||
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
|
||||
except ImportError: # package not installed, skip
|
||||
pass
|
||||
except Exception as e:
|
||||
LOGGER.info(f"{prefix}{e}")
|
||||
LOGGER.info(f'{prefix}{e}')
|
||||
|
||||
def __call__(self, labels):
|
||||
im = labels["img"]
|
||||
cls = labels["cls"]
|
||||
im = labels['img']
|
||||
cls = labels['cls']
|
||||
if len(cls):
|
||||
labels["instances"].convert_bbox("xywh")
|
||||
labels["instances"].normalize(*im.shape[:2][::-1])
|
||||
bboxes = labels["instances"].bboxes
|
||||
labels['instances'].convert_bbox('xywh')
|
||||
labels['instances'].normalize(*im.shape[:2][::-1])
|
||||
bboxes = labels['instances'].bboxes
|
||||
# TODO: add supports of segments and keypoints
|
||||
if self.transform and random.random() < self.p:
|
||||
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
|
||||
labels["img"] = new["image"]
|
||||
labels["cls"] = np.array(new["class_labels"])
|
||||
bboxes = np.array(new["bboxes"])
|
||||
labels["instances"].update(bboxes=bboxes)
|
||||
labels['img'] = new['image']
|
||||
labels['cls'] = np.array(new['class_labels'])
|
||||
bboxes = np.array(new['bboxes'])
|
||||
labels['instances'].update(bboxes=bboxes)
|
||||
return labels
|
||||
|
||||
|
||||
@ -594,7 +594,7 @@ class Albumentations:
|
||||
class Format:
|
||||
|
||||
def __init__(self,
|
||||
bbox_format="xywh",
|
||||
bbox_format='xywh',
|
||||
normalize=True,
|
||||
return_mask=False,
|
||||
return_keypoint=False,
|
||||
@ -610,10 +610,10 @@ class Format:
|
||||
self.batch_idx = batch_idx # keep the batch indexes
|
||||
|
||||
def __call__(self, labels):
|
||||
img = labels.pop("img")
|
||||
img = labels.pop('img')
|
||||
h, w = img.shape[:2]
|
||||
cls = labels.pop("cls")
|
||||
instances = labels.pop("instances")
|
||||
cls = labels.pop('cls')
|
||||
instances = labels.pop('instances')
|
||||
instances.convert_bbox(format=self.bbox_format)
|
||||
instances.denormalize(w, h)
|
||||
nl = len(instances)
|
||||
@ -625,17 +625,17 @@ class Format:
|
||||
else:
|
||||
masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
|
||||
img.shape[1] // self.mask_ratio)
|
||||
labels["masks"] = masks
|
||||
labels['masks'] = masks
|
||||
if self.normalize:
|
||||
instances.normalize(w, h)
|
||||
labels["img"] = self._format_img(img)
|
||||
labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
|
||||
labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
|
||||
labels['img'] = self._format_img(img)
|
||||
labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl)
|
||||
labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
|
||||
if self.return_keypoint:
|
||||
labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
|
||||
labels['keypoints'] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
|
||||
# then we can use collate_fn
|
||||
if self.batch_idx:
|
||||
labels["batch_idx"] = torch.zeros(nl)
|
||||
labels['batch_idx'] = torch.zeros(nl)
|
||||
return labels
|
||||
|
||||
def _format_img(self, img):
|
||||
@ -676,15 +676,15 @@ def v8_transforms(dataset, imgsz, hyp):
|
||||
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
|
||||
Albumentations(p=1.0),
|
||||
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
||||
RandomFlip(direction="vertical", p=hyp.flipud),
|
||||
RandomFlip(direction="horizontal", p=hyp.fliplr),]) # transforms
|
||||
RandomFlip(direction='vertical', p=hyp.flipud),
|
||||
RandomFlip(direction='horizontal', p=hyp.fliplr),]) # transforms
|
||||
|
||||
|
||||
# Classification augmentations -----------------------------------------------------------------------------------------
|
||||
def classify_transforms(size=224):
|
||||
# Transforms to apply if albumentations not installed
|
||||
if not isinstance(size, int):
|
||||
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
|
||||
raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
|
||||
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||
|
||||
@ -701,17 +701,17 @@ def classify_albumentations(
|
||||
auto_aug=False,
|
||||
):
|
||||
# YOLOv8 classification Albumentations (optional, only used if package is installed)
|
||||
prefix = colorstr("albumentations: ")
|
||||
prefix = colorstr('albumentations: ')
|
||||
try:
|
||||
import albumentations as A
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
|
||||
check_version(A.__version__, "1.0.3", hard=True) # version requirement
|
||||
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
||||
if augment: # Resize and crop
|
||||
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
|
||||
if auto_aug:
|
||||
# TODO: implement AugMix, AutoAug & RandAug in albumentation
|
||||
LOGGER.info(f"{prefix}auto augmentations are currently not supported")
|
||||
LOGGER.info(f'{prefix}auto augmentations are currently not supported')
|
||||
else:
|
||||
if hflip > 0:
|
||||
T += [A.HorizontalFlip(p=hflip)]
|
||||
@ -723,13 +723,13 @@ def classify_albumentations(
|
||||
else: # Use fixed crop for eval set (reproducibility)
|
||||
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
|
||||
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
|
||||
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
|
||||
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
|
||||
return A.Compose(T)
|
||||
|
||||
except ImportError: # package not installed, skip
|
||||
pass
|
||||
except Exception as e:
|
||||
LOGGER.info(f"{prefix}{e}")
|
||||
LOGGER.info(f'{prefix}{e}')
|
||||
|
||||
|
||||
class ClassifyLetterBox:
|
||||
|
@ -31,7 +31,7 @@ class BaseDataset(Dataset):
|
||||
cache=False,
|
||||
augment=True,
|
||||
hyp=None,
|
||||
prefix="",
|
||||
prefix='',
|
||||
rect=False,
|
||||
batch_size=None,
|
||||
stride=32,
|
||||
@ -63,7 +63,7 @@ class BaseDataset(Dataset):
|
||||
|
||||
# cache stuff
|
||||
self.ims = [None] * self.ni
|
||||
self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
|
||||
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
|
||||
if cache:
|
||||
self.cache_images(cache)
|
||||
|
||||
@ -77,21 +77,21 @@ class BaseDataset(Dataset):
|
||||
for p in img_path if isinstance(img_path, list) else [img_path]:
|
||||
p = Path(p) # os-agnostic
|
||||
if p.is_dir(): # dir
|
||||
f += glob.glob(str(p / "**" / "*.*"), recursive=True)
|
||||
f += glob.glob(str(p / '**' / '*.*'), recursive=True)
|
||||
# f = list(p.rglob('*.*')) # pathlib
|
||||
elif p.is_file(): # file
|
||||
with open(p) as t:
|
||||
t = t.read().strip().splitlines()
|
||||
parent = str(p.parent) + os.sep
|
||||
f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path
|
||||
f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
|
||||
# f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
|
||||
else:
|
||||
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
|
||||
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
|
||||
raise FileNotFoundError(f'{self.prefix}{p} does not exist')
|
||||
im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
|
||||
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
||||
assert im_files, f"{self.prefix}No images found"
|
||||
assert im_files, f'{self.prefix}No images found'
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
|
||||
raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
|
||||
return im_files
|
||||
|
||||
def update_labels(self, include_class: Optional[list]):
|
||||
@ -99,16 +99,16 @@ class BaseDataset(Dataset):
|
||||
include_class_array = np.array(include_class).reshape(1, -1)
|
||||
for i in range(len(self.labels)):
|
||||
if include_class:
|
||||
cls = self.labels[i]["cls"]
|
||||
bboxes = self.labels[i]["bboxes"]
|
||||
segments = self.labels[i]["segments"]
|
||||
cls = self.labels[i]['cls']
|
||||
bboxes = self.labels[i]['bboxes']
|
||||
segments = self.labels[i]['segments']
|
||||
j = (cls == include_class_array).any(1)
|
||||
self.labels[i]["cls"] = cls[j]
|
||||
self.labels[i]["bboxes"] = bboxes[j]
|
||||
self.labels[i]['cls'] = cls[j]
|
||||
self.labels[i]['bboxes'] = bboxes[j]
|
||||
if segments:
|
||||
self.labels[i]["segments"] = segments[j]
|
||||
self.labels[i]['segments'] = segments[j]
|
||||
if self.single_cls:
|
||||
self.labels[i]["cls"][:, 0] = 0
|
||||
self.labels[i]['cls'][:, 0] = 0
|
||||
|
||||
def load_image(self, i):
|
||||
# Loads 1 image from dataset index 'i', returns (im, resized hw)
|
||||
@ -119,7 +119,7 @@ class BaseDataset(Dataset):
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
if im is None:
|
||||
raise FileNotFoundError(f"Image Not Found {f}")
|
||||
raise FileNotFoundError(f'Image Not Found {f}')
|
||||
h0, w0 = im.shape[:2] # orig hw
|
||||
r = self.imgsz / max(h0, w0) # ratio
|
||||
if r != 1: # if sizes are not equal
|
||||
@ -132,17 +132,17 @@ class BaseDataset(Dataset):
|
||||
# cache images to memory or disk
|
||||
gb = 0 # Gigabytes of cached images
|
||||
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
|
||||
fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
|
||||
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(fcn, range(self.ni))
|
||||
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if cache == "disk":
|
||||
if cache == 'disk':
|
||||
gb += self.npy_files[i].stat().st_size
|
||||
else: # 'ram'
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||
gb += self.ims[i].nbytes
|
||||
pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
|
||||
pbar.desc = f'{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})'
|
||||
pbar.close()
|
||||
|
||||
def cache_images_to_disk(self, i):
|
||||
@ -155,7 +155,7 @@ class BaseDataset(Dataset):
|
||||
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
|
||||
nb = bi[-1] + 1 # number of batches
|
||||
|
||||
s = np.array([x.pop("shape") for x in self.labels]) # hw
|
||||
s = np.array([x.pop('shape') for x in self.labels]) # hw
|
||||
ar = s[:, 0] / s[:, 1] # aspect ratio
|
||||
irect = ar.argsort()
|
||||
self.im_files = [self.im_files[i] for i in irect]
|
||||
@ -180,14 +180,14 @@ class BaseDataset(Dataset):
|
||||
|
||||
def get_label_info(self, index):
|
||||
label = self.labels[index].copy()
|
||||
label.pop("shape", None) # shape is for rect, remove it
|
||||
label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
|
||||
label["ratio_pad"] = (
|
||||
label["resized_shape"][0] / label["ori_shape"][0],
|
||||
label["resized_shape"][1] / label["ori_shape"][1],
|
||||
label.pop('shape', None) # shape is for rect, remove it
|
||||
label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
|
||||
label['ratio_pad'] = (
|
||||
label['resized_shape'][0] / label['ori_shape'][0],
|
||||
label['resized_shape'][1] / label['ori_shape'][1],
|
||||
) # for evaluation
|
||||
if self.rect:
|
||||
label["rect_shape"] = self.batch_shapes[self.batch[index]]
|
||||
label['rect_shape'] = self.batch_shapes[self.batch[index]]
|
||||
label = self.update_labels_info(label)
|
||||
return label
|
||||
|
||||
|
@ -28,7 +28,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
|
||||
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
|
||||
self.iterator = super().__iter__()
|
||||
|
||||
def __len__(self):
|
||||
@ -61,9 +61,9 @@ def seed_worker(worker_id):
|
||||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode="train"):
|
||||
assert mode in ["train", "val"]
|
||||
shuffle = mode == "train"
|
||||
def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode='train'):
|
||||
assert mode in ['train', 'val']
|
||||
shuffle = mode == 'train'
|
||||
if cfg.rect and shuffle:
|
||||
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||
shuffle = False
|
||||
@ -72,21 +72,21 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
|
||||
img_path=img_path,
|
||||
imgsz=cfg.imgsz,
|
||||
batch_size=batch,
|
||||
augment=mode == "train", # augmentation
|
||||
augment=mode == 'train', # augmentation
|
||||
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
||||
rect=cfg.rect or rect, # rectangular batches
|
||||
cache=cfg.cache or None,
|
||||
single_cls=cfg.single_cls or False,
|
||||
stride=int(stride),
|
||||
pad=0.0 if mode == "train" else 0.5,
|
||||
prefix=colorstr(f"{mode}: "),
|
||||
use_segments=cfg.task == "segment",
|
||||
use_keypoints=cfg.task == "keypoint",
|
||||
pad=0.0 if mode == 'train' else 0.5,
|
||||
prefix=colorstr(f'{mode}: '),
|
||||
use_segments=cfg.task == 'segment',
|
||||
use_keypoints=cfg.task == 'keypoint',
|
||||
names=names)
|
||||
|
||||
batch = min(batch, len(dataset))
|
||||
nd = torch.cuda.device_count() # number of CUDA devices
|
||||
workers = cfg.workers if mode == "train" else cfg.workers * 2
|
||||
workers = cfg.workers if mode == 'train' else cfg.workers * 2
|
||||
nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers
|
||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader # allow attribute updates
|
||||
@ -98,7 +98,7 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
|
||||
num_workers=nw,
|
||||
sampler=sampler,
|
||||
pin_memory=PIN_MEMORY,
|
||||
collate_fn=getattr(dataset, "collate_fn", None),
|
||||
collate_fn=getattr(dataset, 'collate_fn', None),
|
||||
worker_init_fn=seed_worker,
|
||||
generator=generator), dataset
|
||||
|
||||
@ -151,7 +151,7 @@ def check_source(source):
|
||||
from_img = True
|
||||
else:
|
||||
raise Exception(
|
||||
"Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
|
||||
'Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict')
|
||||
|
||||
return source, webcam, screenshot, from_img, in_memory
|
||||
|
||||
|
@ -47,7 +47,7 @@ class LoadStreams:
|
||||
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
|
||||
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
|
||||
import pafy # noqa
|
||||
s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
|
||||
s = pafy.new(s).getbest(preftype='mp4').url # YouTube URL
|
||||
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
||||
if s == 0 and (is_colab() or is_kaggle()):
|
||||
raise NotImplementedError("'source=0' webcam not supported in Colab and Kaggle notebooks. "
|
||||
@ -65,7 +65,7 @@ class LoadStreams:
|
||||
if not success or self.imgs[i] is None:
|
||||
raise ConnectionError(f'{st}Failed to read images from {s}')
|
||||
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
|
||||
LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
|
||||
LOGGER.info(f'{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)')
|
||||
self.threads[i].start()
|
||||
LOGGER.info('') # newline
|
||||
|
||||
@ -145,11 +145,11 @@ class LoadScreenshots:
|
||||
|
||||
# Parse monitor shape
|
||||
monitor = self.sct.monitors[self.screen]
|
||||
self.top = monitor["top"] if top is None else (monitor["top"] + top)
|
||||
self.left = monitor["left"] if left is None else (monitor["left"] + left)
|
||||
self.width = width or monitor["width"]
|
||||
self.height = height or monitor["height"]
|
||||
self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
|
||||
self.top = monitor['top'] if top is None else (monitor['top'] + top)
|
||||
self.left = monitor['left'] if left is None else (monitor['left'] + left)
|
||||
self.width = width or monitor['width']
|
||||
self.height = height or monitor['height']
|
||||
self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
@ -157,7 +157,7 @@ class LoadScreenshots:
|
||||
def __next__(self):
|
||||
# mss screen capture: get raw pixels from the screen as np array
|
||||
im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
|
||||
s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
|
||||
s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
|
||||
|
||||
if self.transforms:
|
||||
im = self.transforms(im0) # transforms
|
||||
@ -172,7 +172,7 @@ class LoadScreenshots:
|
||||
class LoadImages:
|
||||
# YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`
|
||||
def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
||||
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
|
||||
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
|
||||
path = Path(path).read_text().rsplit()
|
||||
files = []
|
||||
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
||||
@ -290,12 +290,12 @@ class LoadPilAndNumpy:
|
||||
self.transforms = transforms
|
||||
self.mode = 'image'
|
||||
# generate fake paths
|
||||
self.paths = [f"image{i}.jpg" for i in range(len(self.im0))]
|
||||
self.paths = [f'image{i}.jpg' for i in range(len(self.im0))]
|
||||
self.bs = len(self.im0)
|
||||
|
||||
@staticmethod
|
||||
def _single_check(im):
|
||||
assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
|
||||
assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
|
||||
if isinstance(im, Image.Image):
|
||||
im = np.asarray(im)[:, :, ::-1]
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
@ -338,16 +338,16 @@ def autocast_list(source):
|
||||
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
|
||||
files.append(im)
|
||||
else:
|
||||
raise TypeError(f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n"
|
||||
f"See https://docs.ultralytics.com/predict for supported source types.")
|
||||
raise TypeError(f'type {type(im).__name__} is not a supported Ultralytics prediction source type. \n'
|
||||
f'See https://docs.ultralytics.com/predict for supported source types.')
|
||||
|
||||
return files
|
||||
|
||||
|
||||
LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
|
||||
|
||||
if __name__ == "__main__":
|
||||
img = cv2.imread(str(ROOT / "assets/bus.jpg"))
|
||||
if __name__ == '__main__':
|
||||
img = cv2.imread(str(ROOT / 'assets/bus.jpg'))
|
||||
dataset = LoadPilAndNumpy(im0=img)
|
||||
for d in dataset:
|
||||
print(d[0])
|
||||
|
@ -92,7 +92,7 @@ def exif_transpose(image):
|
||||
if method is not None:
|
||||
image = image.transpose(method)
|
||||
del exif[0x0112]
|
||||
image.info["exif"] = exif.tobytes()
|
||||
image.info['exif'] = exif.tobytes()
|
||||
return image
|
||||
|
||||
|
||||
@ -217,11 +217,11 @@ class LoadScreenshots:
|
||||
|
||||
# Parse monitor shape
|
||||
monitor = self.sct.monitors[self.screen]
|
||||
self.top = monitor["top"] if top is None else (monitor["top"] + top)
|
||||
self.left = monitor["left"] if left is None else (monitor["left"] + left)
|
||||
self.width = width or monitor["width"]
|
||||
self.height = height or monitor["height"]
|
||||
self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
|
||||
self.top = monitor['top'] if top is None else (monitor['top'] + top)
|
||||
self.left = monitor['left'] if left is None else (monitor['left'] + left)
|
||||
self.width = width or monitor['width']
|
||||
self.height = height or monitor['height']
|
||||
self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
@ -229,7 +229,7 @@ class LoadScreenshots:
|
||||
def __next__(self):
|
||||
# mss screen capture: get raw pixels from the screen as np array
|
||||
im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
|
||||
s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
|
||||
s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '
|
||||
|
||||
if self.transforms:
|
||||
im = self.transforms(im0) # transforms
|
||||
@ -244,7 +244,7 @@ class LoadScreenshots:
|
||||
class LoadImages:
|
||||
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
|
||||
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
||||
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
|
||||
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
|
||||
path = Path(path).read_text().rsplit()
|
||||
files = []
|
||||
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
||||
@ -363,7 +363,7 @@ class LoadStreams:
|
||||
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
|
||||
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
|
||||
import pafy
|
||||
s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
|
||||
s = pafy.new(s).getbest(preftype='mp4').url # YouTube URL
|
||||
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
||||
if s == 0:
|
||||
assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
|
||||
@ -378,7 +378,7 @@ class LoadStreams:
|
||||
|
||||
_, self.imgs[i] = cap.read() # guarantee first frame
|
||||
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
|
||||
LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
|
||||
LOGGER.info(f'{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)')
|
||||
self.threads[i].start()
|
||||
LOGGER.info('') # newline
|
||||
|
||||
@ -500,7 +500,7 @@ class LoadImagesAndLabels(Dataset):
|
||||
# Display cache
|
||||
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
|
||||
if exists and LOCAL_RANK in {-1, 0}:
|
||||
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
||||
tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
|
||||
if cache['msgs']:
|
||||
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
|
||||
@ -604,8 +604,8 @@ class LoadImagesAndLabels(Dataset):
|
||||
mem = psutil.virtual_memory()
|
||||
cache = mem_required * (1 + safety_margin) < mem.available # to cache or not to cache, that is the question
|
||||
if not cache:
|
||||
LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
|
||||
f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
|
||||
LOGGER.info(f'{prefix}{mem_required / gb:.1f}GB RAM required, '
|
||||
f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
|
||||
f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
|
||||
return cache
|
||||
|
||||
@ -615,7 +615,7 @@ class LoadImagesAndLabels(Dataset):
|
||||
path.unlink() # remove *.cache file if exists
|
||||
x = {} # dict
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f"{prefix}Scanning {path.parent / path.stem}..."
|
||||
desc = f'{prefix}Scanning {path.parent / path.stem}...'
|
||||
total = len(self.im_files)
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix)))
|
||||
@ -629,7 +629,7 @@ class LoadImagesAndLabels(Dataset):
|
||||
x[im_file] = [lb, shape, segments]
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
||||
pbar.close()
|
||||
|
||||
if msgs:
|
||||
@ -1060,7 +1060,7 @@ class HUBDatasetStats():
|
||||
if zipped:
|
||||
data['path'] = data_dir
|
||||
except Exception as e:
|
||||
raise Exception("error/HUB/dataset_stats/yaml_load") from e
|
||||
raise Exception('error/HUB/dataset_stats/yaml_load') from e
|
||||
|
||||
check_det_dataset(data, autodownload) # download dataset if missing
|
||||
self.hub_dir = Path(data['path'] + '-hub')
|
||||
@ -1187,7 +1187,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
if self.album_transforms:
|
||||
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
|
||||
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
|
||||
else:
|
||||
sample = self.torch_transforms(im)
|
||||
return sample, j
|
||||
|
@ -28,7 +28,7 @@ class YOLODataset(BaseDataset):
|
||||
cache=False,
|
||||
augment=True,
|
||||
hyp=None,
|
||||
prefix="",
|
||||
prefix='',
|
||||
rect=False,
|
||||
batch_size=None,
|
||||
stride=32,
|
||||
@ -40,14 +40,14 @@ class YOLODataset(BaseDataset):
|
||||
self.use_segments = use_segments
|
||||
self.use_keypoints = use_keypoints
|
||||
self.names = names
|
||||
assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
|
||||
assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
|
||||
super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)
|
||||
|
||||
def cache_labels(self, path=Path("./labels.cache")):
|
||||
def cache_labels(self, path=Path('./labels.cache')):
|
||||
# Cache dataset labels, check images and read shapes
|
||||
x = {"labels": []}
|
||||
x = {'labels': []}
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
||||
desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
|
||||
total = len(self.im_files)
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(func=verify_image_label,
|
||||
@ -60,7 +60,7 @@ class YOLODataset(BaseDataset):
|
||||
ne += ne_f
|
||||
nc += nc_f
|
||||
if im_file:
|
||||
x["labels"].append(
|
||||
x['labels'].append(
|
||||
dict(
|
||||
im_file=im_file,
|
||||
shape=shape,
|
||||
@ -69,68 +69,68 @@ class YOLODataset(BaseDataset):
|
||||
segments=segments,
|
||||
keypoints=keypoint,
|
||||
normalized=True,
|
||||
bbox_format="xywh"))
|
||||
bbox_format='xywh'))
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
||||
pbar.close()
|
||||
|
||||
if msgs:
|
||||
LOGGER.info("\n".join(msgs))
|
||||
LOGGER.info('\n'.join(msgs))
|
||||
if nf == 0:
|
||||
LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
|
||||
x["hash"] = get_hash(self.label_files + self.im_files)
|
||||
x["results"] = nf, nm, ne, nc, len(self.im_files)
|
||||
x["msgs"] = msgs # warnings
|
||||
x["version"] = self.cache_version # cache version
|
||||
LOGGER.warning(f'{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
|
||||
x['hash'] = get_hash(self.label_files + self.im_files)
|
||||
x['results'] = nf, nm, ne, nc, len(self.im_files)
|
||||
x['msgs'] = msgs # warnings
|
||||
x['version'] = self.cache_version # cache version
|
||||
if is_dir_writeable(path.parent):
|
||||
if path.exists():
|
||||
path.unlink() # remove *.cache file if exists
|
||||
np.save(str(path), x) # save cache for next time
|
||||
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
||||
LOGGER.info(f"{self.prefix}New cache created: {path}")
|
||||
path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
|
||||
LOGGER.info(f'{self.prefix}New cache created: {path}')
|
||||
else:
|
||||
LOGGER.warning(f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
|
||||
LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
|
||||
return x
|
||||
|
||||
def get_labels(self):
|
||||
self.label_files = img2label_paths(self.im_files)
|
||||
cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
|
||||
cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
|
||||
try:
|
||||
cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
|
||||
assert cache["version"] == self.cache_version # matches current version
|
||||
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
|
||||
assert cache['version'] == self.cache_version # matches current version
|
||||
assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
|
||||
except (FileNotFoundError, AssertionError, AttributeError):
|
||||
cache, exists = self.cache_labels(cache_path), False # run cache ops
|
||||
|
||||
# Display cache
|
||||
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
|
||||
nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
|
||||
if exists and LOCAL_RANK in {-1, 0}:
|
||||
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
|
||||
tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
|
||||
if cache["msgs"]:
|
||||
LOGGER.info("\n".join(cache["msgs"])) # display warnings
|
||||
if cache['msgs']:
|
||||
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
|
||||
if nf == 0: # number of labels found
|
||||
raise FileNotFoundError(f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}")
|
||||
raise FileNotFoundError(f'{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}')
|
||||
|
||||
# Read cache
|
||||
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
|
||||
labels = cache["labels"]
|
||||
self.im_files = [lb["im_file"] for lb in labels] # update im_files
|
||||
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
|
||||
labels = cache['labels']
|
||||
self.im_files = [lb['im_file'] for lb in labels] # update im_files
|
||||
|
||||
# Check if the dataset is all boxes or all segments
|
||||
len_cls = sum(len(lb["cls"]) for lb in labels)
|
||||
len_boxes = sum(len(lb["bboxes"]) for lb in labels)
|
||||
len_segments = sum(len(lb["segments"]) for lb in labels)
|
||||
len_cls = sum(len(lb['cls']) for lb in labels)
|
||||
len_boxes = sum(len(lb['bboxes']) for lb in labels)
|
||||
len_segments = sum(len(lb['segments']) for lb in labels)
|
||||
if len_segments and len_boxes != len_segments:
|
||||
LOGGER.warning(
|
||||
f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
|
||||
f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
|
||||
"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.")
|
||||
f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, '
|
||||
f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. '
|
||||
'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.')
|
||||
for lb in labels:
|
||||
lb["segments"] = []
|
||||
lb['segments'] = []
|
||||
if len_cls == 0:
|
||||
raise ValueError(f"All labels empty in {cache_path}, can not start training without labels. {HELP_URL}")
|
||||
raise ValueError(f'All labels empty in {cache_path}, can not start training without labels. {HELP_URL}')
|
||||
return labels
|
||||
|
||||
# TODO: use hyp config to set all these augmentations
|
||||
@ -142,7 +142,7 @@ class YOLODataset(BaseDataset):
|
||||
else:
|
||||
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
|
||||
transforms.append(
|
||||
Format(bbox_format="xywh",
|
||||
Format(bbox_format='xywh',
|
||||
normalize=True,
|
||||
return_mask=self.use_segments,
|
||||
return_keypoint=self.use_keypoints,
|
||||
@ -161,12 +161,12 @@ class YOLODataset(BaseDataset):
|
||||
"""custom your label format here"""
|
||||
# NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
|
||||
# we can make it also support classification and semantic segmentation by add or remove some dict keys there.
|
||||
bboxes = label.pop("bboxes")
|
||||
segments = label.pop("segments")
|
||||
keypoints = label.pop("keypoints", None)
|
||||
bbox_format = label.pop("bbox_format")
|
||||
normalized = label.pop("normalized")
|
||||
label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
|
||||
bboxes = label.pop('bboxes')
|
||||
segments = label.pop('segments')
|
||||
keypoints = label.pop('keypoints', None)
|
||||
bbox_format = label.pop('bbox_format')
|
||||
normalized = label.pop('normalized')
|
||||
label['instances'] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
|
||||
return label
|
||||
|
||||
@staticmethod
|
||||
@ -176,15 +176,15 @@ class YOLODataset(BaseDataset):
|
||||
values = list(zip(*[list(b.values()) for b in batch]))
|
||||
for i, k in enumerate(keys):
|
||||
value = values[i]
|
||||
if k == "img":
|
||||
if k == 'img':
|
||||
value = torch.stack(value, 0)
|
||||
if k in ["masks", "keypoints", "bboxes", "cls"]:
|
||||
if k in ['masks', 'keypoints', 'bboxes', 'cls']:
|
||||
value = torch.cat(value, 0)
|
||||
new_batch[k] = value
|
||||
new_batch["batch_idx"] = list(new_batch["batch_idx"])
|
||||
for i in range(len(new_batch["batch_idx"])):
|
||||
new_batch["batch_idx"][i] += i # add target image index for build_targets()
|
||||
new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
|
||||
new_batch['batch_idx'] = list(new_batch['batch_idx'])
|
||||
for i in range(len(new_batch['batch_idx'])):
|
||||
new_batch['batch_idx'][i] += i # add target image index for build_targets()
|
||||
new_batch['batch_idx'] = torch.cat(new_batch['batch_idx'], 0)
|
||||
return new_batch
|
||||
|
||||
|
||||
@ -202,9 +202,9 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||
super().__init__(root=root)
|
||||
self.torch_transforms = classify_transforms(imgsz)
|
||||
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
|
||||
self.cache_ram = cache is True or cache == "ram"
|
||||
self.cache_disk = cache == "disk"
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
|
||||
self.cache_ram = cache is True or cache == 'ram'
|
||||
self.cache_disk = cache == 'disk'
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
|
||||
|
||||
def __getitem__(self, i):
|
||||
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
||||
@ -217,7 +217,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
if self.album_transforms:
|
||||
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
|
||||
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
|
||||
else:
|
||||
sample = self.torch_transforms(im)
|
||||
return {'img': sample, 'cls': j}
|
||||
|
@ -25,15 +25,15 @@ class MixAndRectDataset:
|
||||
labels = deepcopy(self.dataset[index])
|
||||
for transform in self.dataset.transforms.tolist():
|
||||
# mosaic and mixup
|
||||
if hasattr(transform, "get_indexes"):
|
||||
if hasattr(transform, 'get_indexes'):
|
||||
indexes = transform.get_indexes(self.dataset)
|
||||
if not isinstance(indexes, collections.abc.Sequence):
|
||||
indexes = [indexes]
|
||||
mix_labels = [deepcopy(self.dataset[index]) for index in indexes]
|
||||
labels["mix_labels"] = mix_labels
|
||||
labels['mix_labels'] = mix_labels
|
||||
if self.dataset.rect and isinstance(transform, LetterBox):
|
||||
transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
|
||||
labels = transform(labels)
|
||||
if "mix_labels" in labels:
|
||||
labels.pop("mix_labels")
|
||||
if 'mix_labels' in labels:
|
||||
labels.pop('mix_labels')
|
||||
return labels
|
||||
|
@ -55,4 +55,4 @@ download: |
|
||||
for r in x[images == im]:
|
||||
w, h = r[6], r[7] # image width, height
|
||||
xywh = xyxy2xywh(np.array([[r[1] / w, r[2] / h, r[3] / w, r[4] / h]]))[0] # instance
|
||||
f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n") # write label
|
||||
f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n") # write label
|
||||
|
@ -112,4 +112,4 @@ download: |
|
||||
urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
|
||||
'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
|
||||
'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
|
||||
download(urls, dir=dir / 'images', threads=3)
|
||||
download(urls, dir=dir / 'images', threads=3)
|
||||
|
@ -98,4 +98,4 @@ names:
|
||||
|
||||
|
||||
# Download script/URL (optional)
|
||||
download: https://ultralytics.com/assets/coco128-seg.zip
|
||||
download: https://ultralytics.com/assets/coco128-seg.zip
|
||||
|
@ -98,4 +98,4 @@ names:
|
||||
|
||||
|
||||
# Download script/URL (optional)
|
||||
download: https://ultralytics.com/assets/coco128.zip
|
||||
download: https://ultralytics.com/assets/coco128.zip
|
||||
|
@ -98,4 +98,4 @@ names:
|
||||
|
||||
|
||||
# Download script/URL (optional)
|
||||
download: https://ultralytics.com/assets/coco8-seg.zip
|
||||
download: https://ultralytics.com/assets/coco8-seg.zip
|
||||
|
@ -98,4 +98,4 @@ names:
|
||||
|
||||
|
||||
# Download script/URL (optional)
|
||||
download: https://ultralytics.com/assets/coco8.zip
|
||||
download: https://ultralytics.com/assets/coco8.zip
|
||||
|
@ -18,32 +18,32 @@ from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
|
||||
from ultralytics.yolo.utils.downloads import download, safe_download
|
||||
from ultralytics.yolo.utils.ops import segments2boxes
|
||||
|
||||
HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
|
||||
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # include image suffixes
|
||||
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv" # include video suffixes
|
||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
|
||||
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # include image suffixes
|
||||
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
|
||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
|
||||
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
|
||||
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
|
||||
IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
|
||||
|
||||
# Get orientation exif tag
|
||||
for orientation in ExifTags.TAGS.keys():
|
||||
if ExifTags.TAGS[orientation] == "Orientation":
|
||||
if ExifTags.TAGS[orientation] == 'Orientation':
|
||||
break
|
||||
|
||||
|
||||
def img2label_paths(img_paths):
|
||||
# Define label paths as a function of image paths
|
||||
sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings
|
||||
return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
|
||||
sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
|
||||
return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
|
||||
|
||||
|
||||
def get_hash(paths):
|
||||
# Returns a single hash value of a list of paths (files or dirs)
|
||||
size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
|
||||
h = hashlib.sha256(str(size).encode()) # hash sizes
|
||||
h.update("".join(paths).encode()) # hash paths
|
||||
h.update(''.join(paths).encode()) # hash paths
|
||||
return h.hexdigest() # return hash
|
||||
|
||||
|
||||
@ -61,21 +61,21 @@ def verify_image_label(args):
|
||||
# Verify one image-label pair
|
||||
im_file, lb_file, prefix, keypoint, num_cls = args
|
||||
# number (missing, found, empty, corrupt), message, segments, keypoints
|
||||
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
|
||||
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
|
||||
try:
|
||||
# verify images
|
||||
im = Image.open(im_file)
|
||||
im.verify() # PIL verify
|
||||
shape = exif_size(im) # image size
|
||||
shape = (shape[1], shape[0]) # hw
|
||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
||||
if im.format.lower() in ("jpg", "jpeg"):
|
||||
with open(im_file, "rb") as f:
|
||||
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
|
||||
assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
|
||||
if im.format.lower() in ('jpg', 'jpeg'):
|
||||
with open(im_file, 'rb') as f:
|
||||
f.seek(-2, 2)
|
||||
if f.read() != b"\xff\xd9": # corrupt JPEG
|
||||
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
|
||||
msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
|
||||
if f.read() != b'\xff\xd9': # corrupt JPEG
|
||||
ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
|
||||
msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
|
||||
|
||||
# verify labels
|
||||
if os.path.isfile(lb_file):
|
||||
@ -90,31 +90,31 @@ def verify_image_label(args):
|
||||
nl = len(lb)
|
||||
if nl:
|
||||
if keypoint:
|
||||
assert lb.shape[1] == 56, "labels require 56 columns each"
|
||||
assert (lb[:, 5::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
|
||||
assert (lb[:, 6::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
|
||||
assert lb.shape[1] == 56, 'labels require 56 columns each'
|
||||
assert (lb[:, 5::3] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
|
||||
assert (lb[:, 6::3] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
|
||||
kpts = np.zeros((lb.shape[0], 39))
|
||||
for i in range(len(lb)):
|
||||
kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5, 3)) # remove occlusion param from GT
|
||||
kpts[i] = np.hstack((lb[i, :5], kpt))
|
||||
lb = kpts
|
||||
assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
|
||||
assert lb.shape[1] == 39, 'labels require 39 columns each after removing occlusion parameter'
|
||||
else:
|
||||
assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
|
||||
assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
|
||||
assert (lb[:, 1:] <= 1).all(), \
|
||||
f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
|
||||
f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
|
||||
# All labels
|
||||
max_cls = int(lb[:, 0].max()) # max label count
|
||||
assert max_cls <= num_cls, \
|
||||
f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
|
||||
f'Possible class labels are 0-{num_cls - 1}'
|
||||
assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
|
||||
assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
|
||||
_, i = np.unique(lb, axis=0, return_index=True)
|
||||
if len(i) < nl: # duplicate row check
|
||||
lb = lb[i] # remove duplicates
|
||||
if segments:
|
||||
segments = [segments[x] for x in i]
|
||||
msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
|
||||
msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
|
||||
else:
|
||||
ne = 1 # label empty
|
||||
lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
|
||||
@ -127,7 +127,7 @@ def verify_image_label(args):
|
||||
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
|
||||
except Exception as e:
|
||||
nc = 1
|
||||
msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
|
||||
msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
|
||||
return [None, None, None, None, None, nm, nf, ne, nc, msg]
|
||||
|
||||
|
||||
@ -248,8 +248,8 @@ def check_det_dataset(dataset, autodownload=True):
|
||||
else: # python script
|
||||
r = exec(s, {'yaml': data}) # return None
|
||||
dt = f'({round(time.time() - t, 1)}s)'
|
||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
|
||||
LOGGER.info(f"Dataset download {s}\n")
|
||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'
|
||||
LOGGER.info(f'Dataset download {s}\n')
|
||||
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
|
||||
|
||||
return data # dictionary
|
||||
@ -284,9 +284,9 @@ def check_cls_dataset(dataset: str):
|
||||
download(url, dir=data_dir.parent)
|
||||
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
|
||||
LOGGER.info(s)
|
||||
train_set = data_dir / "train"
|
||||
train_set = data_dir / 'train'
|
||||
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
|
||||
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
|
||||
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
|
||||
names = dict(enumerate(sorted(names)))
|
||||
return {"train": train_set, "val": test_set, "nc": nc, "names": names}
|
||||
return {'train': train_set, 'val': test_set, 'nc': nc, 'names': names}
|
||||
|
Reference in New Issue
Block a user