Import YOLOv5 dataloader (#94)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

.github/workflows/ci.yaml: 10 changed lines (vendored)
@@ -20,21 +20,21 @@ jobs:
       matrix:
         os: [ ubuntu-latest ]
         python-version: [ '3.10' ]
-        model: [ yolov5n ]
+        model: [ yolov8n ]
         torch: [ latest ]
 #        include:
 #          - os: ubuntu-latest
 #            python-version: '3.7'  # '3.6.8' min
-#            model: yolov5n
+#            model: yolov8n
 #          - os: ubuntu-latest
 #            python-version: '3.8'
-#            model: yolov5n
+#            model: yolov8n
 #          - os: ubuntu-latest
 #            python-version: '3.9'
-#            model: yolov5n
+#            model: yolov8n
 #          - os: ubuntu-latest
 #            python-version: '3.8'  # torch 1.7.0 requires python >=3.6, <=3.8
-#            model: yolov5n
+#            model: yolov8n
 #            torch: '1.7.0'  # min torch version CI https://pypi.org/project/torchvision/
     steps:
       - uses: actions/checkout@v3

ultralytics/yolo/data/dataloaders/v5augmentations.py: 402 lines added (new file)
@@ -0,0 +1,402 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Image augmentation functions
"""

import math
import random

import cv2
import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from ultralytics.yolo.utils import LOGGER, colorstr
from ultralytics.yolo.utils.checks import check_version
from ultralytics.yolo.utils.metrics import bbox_ioa
from ultralytics.yolo.utils.ops import resample_segments, segment2box, xywhn2xyxy

IMAGENET_MEAN = 0.485, 0.456, 0.406  # RGB mean
IMAGENET_STD = 0.229, 0.224, 0.225  # RGB standard deviation


class Albumentations:
    # YOLOv5 Albumentations class (optional, only used if package is installed)
    def __init__(self, size=640):
        self.transform = None
        prefix = colorstr('albumentations: ')
        try:
            import albumentations as A
            check_version(A.__version__, '1.0.3', hard=True)  # version requirement

            T = [
                A.RandomResizedCrop(height=size, width=size, scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0),
                A.Blur(p=0.01),
                A.MedianBlur(p=0.01),
                A.ToGray(p=0.01),
                A.CLAHE(p=0.01),
                A.RandomBrightnessContrast(p=0.0),
                A.RandomGamma(p=0.0),
                A.ImageCompression(quality_lower=75, p=0.0)]  # transforms
            self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))

            LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
        except ImportError:  # package not installed, skip
            pass
        except Exception as e:
            LOGGER.info(f'{prefix}{e}')

    def __call__(self, im, labels, p=1.0):
        if self.transform and random.random() < p:
            new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0])  # transformed
            im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
        return im, labels

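# Editor's note (illustrative sketch, not part of the vendored file): `im` is assumed to be a BGR
# uint8 image and `labels` an (n, 5) array of (class, x_center, y_center, width, height) normalized
# to 0-1, matching the 'yolo' bbox format declared in A.BboxParams above.
#   transform = Albumentations(size=640)
#   im, labels = transform(im, labels, p=1.0)  # silently a no-op if albumentations is not installed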

def normalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=False):
    # Normalize RGB images x per ImageNet stats in BCHW format, i.e. = (x - mean) / std
    return TF.normalize(x, mean, std, inplace=inplace)


def denormalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    # Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = x * std + mean
    for i in range(3):
        x[:, i] = x[:, i] * std[i] + mean[i]
    return x


def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
    # HSV color-space augmentation
    if hgain or sgain or vgain:
        r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
        hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
        dtype = im.dtype  # uint8

        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
        cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im)  # no return needed

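# Editor's note (illustrative sketch): augment_hsv mutates the BGR uint8 image in place through
# per-channel lookup tables, hence no return value.
#   im = cv2.imread('bus.jpg')  # placeholder path, any BGR uint8 image
#   augment_hsv(im, hgain=0.015, sgain=0.7, vgain=0.4)  # gains match the default hyperparameters below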

def hist_equalize(im, clahe=True, bgr=False):
    # Equalize histogram on BGR (bgr=True) or RGB image 'im' with im.shape(n,m,3) and range 0-255
    yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
    if clahe:
        c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        yuv[:, :, 0] = c.apply(yuv[:, :, 0])
    else:
        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # equalize Y channel histogram
    return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB)  # convert YUV image back to BGR/RGB


def replicate(im, labels):
    # Replicate labels
    h, w = im.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b]  # im4[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return im, labels


def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return im, ratio, (dw, dh)

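# Editor's note (illustrative sketch): the returned ratio and padding are exactly what is needed to
# map coordinates between the original and the letterboxed image.
#   im_padded, ratio, (dw, dh) = letterbox(im, new_shape=640, auto=False)
#   # x_padded = x * ratio[0] + dw, y_padded = y * ratio[1] + dh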

def random_perspective(im,
                       targets=(),
                       segments=(),
                       degrees=10,
                       translate=.1,
                       scale=.1,
                       shear=10,
                       perspective=0.0,
                       border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = im.shape[0] + border[0] * 2  # shape(h,w,c)
    width = im.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -im.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -im.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(im[:, :, ::-1])  # base
    # ax[1].imshow(im2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        use_segments = any(x.any() for x in segments)
        new = np.zeros((n, 4))
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                xy = np.ones((len(segment), 3))
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            xy = np.ones((n * 4, 3))
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return im, targets

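# Editor's note (illustrative sketch): `border` is half the negative mosaic border, so a 1280x1280
# mosaic built with border=(-320, -320) is warped back down to 640x640 here.
#   im, targets = random_perspective(im, targets, degrees=0.0, translate=0.1, scale=0.5,
#                                    shear=0.0, perspective=0.0, border=(-320, -320))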

def copy_paste(im, labels, segments, p=0.5):
    # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
    n = len(segments)
    if p and n:
        h, w, c = im.shape  # height, width, channels
        im_new = np.zeros(im.shape, np.uint8)

        # calculate ioa first then select indexes randomly
        boxes = np.stack([w - labels[:, 3], labels[:, 2], w - labels[:, 1], labels[:, 4]], axis=-1)  # (n, 4)
        ioa = bbox_ioa(boxes, labels[:, 1:5])  # intersection over area
        indexes = np.nonzero((ioa < 0.30).all(1))[0]  # (N, )
        n = len(indexes)
        for j in random.sample(list(indexes), k=round(p * n)):
            l, box, s = labels[j], boxes[j], segments[j]
            labels = np.concatenate((labels, [[l[0], *box]]), 0)
            segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
            cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (1, 1, 1), cv2.FILLED)

        result = cv2.flip(im, 1)  # augment segments (flip left-right)
        i = cv2.flip(im_new, 1).astype(bool)
        im[i] = result[i]  # cv2.imwrite('debug.jpg', im)  # debug

    return im, labels, segments


def cutout(im, labels, p=0.5):
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
    if random.random() < p:
        h, w = im.shape[:2]
        scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
        for s in scales:
            mask_h = random.randint(1, int(h * s))  # create random masks
            mask_w = random.randint(1, int(w * s))

            # box
            xmin = max(0, random.randint(0, w) - mask_w // 2)
            ymin = max(0, random.randint(0, h) - mask_h // 2)
            xmax = min(w, xmin + mask_w)
            ymax = min(h, ymin + mask_h)

            # apply random color mask
            im[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

            # return unobscured labels
            if len(labels) and s > 0.03:
                box = np.array([[xmin, ymin, xmax, ymax]], dtype=np.float32)
                ioa = bbox_ioa(box, xywhn2xyxy(labels[:, 1:5], w, h))[0]  # intersection over area
                labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels


def mixup(im, labels, im2, labels2):
    # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
    r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
    im = (im * r + im2 * (1 - r)).astype(np.uint8)
    labels = np.concatenate((labels, labels2), 0)
    return im, labels

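# Editor's note: with alpha = beta = 32.0 the sampled ratio r concentrates tightly around 0.5, so
# MixUp blends both images near-equally while keeping the labels of both.
#   im, labels = mixup(im, labels, im2, labels2)  # im2/labels2 from a second, independently loaded sample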

def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates


def classify_albumentations(
        augment=True,
        size=224,
        scale=(0.08, 1.0),
        ratio=(0.75, 1.0 / 0.75),  # 0.75, 1.33
        hflip=0.5,
        vflip=0.0,
        jitter=0.4,
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD,
        auto_aug=False):
    # YOLOv5 classification Albumentations (optional, only used if package is installed)
    prefix = colorstr('albumentations: ')
    try:
        import albumentations as A
        from albumentations.pytorch import ToTensorV2
        check_version(A.__version__, '1.0.3', hard=True)  # version requirement
        if augment:  # Resize and crop
            T = [A.RandomResizedCrop(height=size, width=size, scale=scale, ratio=ratio)]
            if auto_aug:
                # TODO: implement AugMix, AutoAug & RandAug in albumentation
                LOGGER.info(f'{prefix}auto augmentations are currently not supported')
            else:
                if hflip > 0:
                    T += [A.HorizontalFlip(p=hflip)]
                if vflip > 0:
                    T += [A.VerticalFlip(p=vflip)]
                if jitter > 0:
                    color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, saturation, 0 hue
                    T += [A.ColorJitter(*color_jitter, 0)]
        else:  # Use fixed crop for eval set (reproducibility)
            T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
        T += [A.Normalize(mean=mean, std=std), ToTensorV2()]  # Normalize and convert to Tensor
        LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
        return A.Compose(T)

    except ImportError:  # package not installed, skip
        LOGGER.warning(f'{prefix}⚠️ not found, install with `pip install albumentations` (recommended)')
    except Exception as e:
        LOGGER.info(f'{prefix}{e}')


def classify_transforms(size=224):
    # Transforms to apply if albumentations not installed
    assert isinstance(size, int), f'ERROR: classify_transforms size {size} must be integer, not (list, tuple)'
    # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
    return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])

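# Editor's note (illustrative sketch): the fallback pipeline takes a BGR numpy image (the ToTensor
# class below handles the BGR -> RGB flip) and returns a normalized (3, size, size) float tensor.
#   x = classify_transforms(size=224)(cv2.imread('bus.jpg'))  # placeholder path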

class LetterBox:
    # YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
    def __init__(self, size=(640, 640), auto=False, stride=32):
        super().__init__()
        self.h, self.w = (size, size) if isinstance(size, int) else size
        self.auto = auto  # pass max size integer, automatically solve for short side using stride
        self.stride = stride  # used with auto

    def __call__(self, im):  # im = np.array HWC
        imh, imw = im.shape[:2]
        r = min(self.h / imh, self.w / imw)  # ratio of new/old
        h, w = round(imh * r), round(imw * r)  # resized image
        hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else (self.h, self.w)
        top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
        im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
        im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
        return im_out


class CenterCrop:
    # YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
    def __init__(self, size=640):
        super().__init__()
        self.h, self.w = (size, size) if isinstance(size, int) else size

    def __call__(self, im):  # im = np.array HWC
        imh, imw = im.shape[:2]
        m = min(imh, imw)  # min dimension
        top, left = (imh - m) // 2, (imw - m) // 2
        return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)


class ToTensor:
    # YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
    def __init__(self, half=False):
        super().__init__()
        self.half = half

    def __call__(self, im):  # im = np.array HWC in BGR order
        im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1])  # HWC to CHW -> BGR to RGB -> contiguous
        im = torch.from_numpy(im)  # to torch
        im = im.half() if self.half else im.float()  # uint8 to fp16/32
        im /= 255.0  # 0-255 to 0.0-1.0
        return im
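
Taken together, LetterBox, CenterCrop and ToTensor form the minimal torchvision-only preprocessing path used when albumentations is unavailable. A self-contained sketch (editor's addition; the image path is a placeholder):

import cv2
import torchvision.transforms as T

from ultralytics.yolo.data.dataloaders.v5augmentations import CenterCrop, LetterBox, ToTensor

im = cv2.imread('bus.jpg')  # BGR HWC uint8, placeholder path
eval_tf = T.Compose([CenterCrop(224), ToTensor()])  # fixed crop, reproducible for eval
train_tf = T.Compose([LetterBox(224), ToTensor()])  # aspect-preserving pad-to-square alternative
x = eval_tf(im)  # float32 CHW tensor in [0, 1], RGB order
print(x.shape)  # torch.Size([3, 224, 224])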
							
								
								
									
ultralytics/yolo/data/dataloaders/v5loader.py: 1216 lines added (new file)
(file diff suppressed because it is too large)
@@ -1,7 +1,7 @@
 import torch
 import yaml
 
-# from ultralytics import yolo
+from ultralytics import yolo  # (required for python usage)
 # from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
 from ultralytics.yolo.utils import LOGGER

@@ -110,8 +110,9 @@ class BaseTrainer:
         world_size = torch.cuda.device_count()
         if world_size > 1 and "LOCAL_RANK" not in os.environ:
             command = generate_ddp_command(world_size, self)
+            print('DDP command: ', command)
             subprocess.Popen(command)
-            ddp_cleanup(command, self)
+            # ddp_cleanup(command, self)  # TODO: uncomment and fix
         else:
             self._do_train(int(os.getenv("RANK", -1)), world_size)

@@ -121,7 +122,7 @@ class BaseTrainer:
         torch.cuda.set_device(rank)
         self.device = torch.device('cuda', rank)
         self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
-        mp.use_start_method('spawn', force=True)
+        mp.set_start_method('spawn', force=True)
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
 
     def _setup_train(self, rank, world_size):
@@ -195,6 +196,11 @@ class BaseTrainer:
             for i, batch in pbar:
                 self.trigger_callbacks("on_train_batch_start")
 
+                # update dataloader attributes (optional)
+                if epoch == (self.epochs - self.args.close_mosaic) and hasattr(self.train_loader.dataset, 'mosaic'):
+                    LOGGER.info("Closing dataloader mosaic")
+                    self.train_loader.dataset.mosaic = False
+
                 # warmup
                 ni = i + nb * epoch
                 if ni <= nw:
@@ -207,7 +213,7 @@ class BaseTrainer:
                         if 'momentum' in x:
                             x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
 
-                # Forward
+                # forward
                 with torch.cuda.amp.autocast(self.amp):
                     batch = self.preprocess_batch(batch)
                     preds = self.model(batch["img"])
@@ -217,10 +223,10 @@ class BaseTrainer:
                     self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
                         else self.loss_items
 
-                # Backward
+                # backward
                 self.scaler.scale(self.loss).backward()
 
-                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
+                # optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                 if ni - last_opt_step >= self.accumulate:
                     self.optimizer_step()
                     last_opt_step = ni

@@ -6,13 +6,15 @@ import sys
 import threading
 from pathlib import Path
 
+import cv2
 import IPython
+import pandas as pd
 
 # Constants
 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[2]  # YOLO
 RANK = int(os.getenv('RANK', -1))
-DATASETS_DIR = ROOT.parent / 'datasets'  # YOLOv5 datasets directory
+DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets'))  # global datasets directory
 NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
 AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
 FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf
@@ -20,6 +22,14 @@ VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global ver
 TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'  # tqdm bar format
 LOGGING_NAME = 'yolov5'
 
+# Settings
+# torch.set_printoptions(linewidth=320, precision=5, profile='long')
+# np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
+pd.options.display.max_columns = 10
+cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
+os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
+os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS)  # OpenMP (PyTorch and SciPy)
+
 
 def is_colab():
     # Is environment a Google Colab instance?

@@ -2,13 +2,13 @@
 # Default training settings and hyperparameters for medium-augmentation COCO training
 
 # Task and Mode
-task: "classify" # choices=['detect', 'segment', 'classify', 'init'] # init is a special case
-mode: "train" # choice=['train', 'val', 'infer']
+task: "classify"  # choices=['detect', 'segment', 'classify', 'init'] # init is a special case
+mode: "train"  # choice=['train', 'val', 'infer']
 
 # Train settings -------------------------------------------------------------------------------------------------------
 model: null  # i.e. yolov5s.pt, yolo.yaml
 data: null  # i.e. coco128.yaml
-epochs: 300
+epochs: 100
 batch_size: 16
 imgsz: 640
 nosave: False
@@ -42,10 +42,10 @@ noval: False
 save_json: False
 save_hybrid: False
 conf_thres: 0.001
-iou_thres: 0.6
+iou_thres: 0.7
 max_det: 300
 half: True
-dnn: False # use OpenCV DNN for ONNX inference
+dnn: False  # use OpenCV DNN for ONNX inference
 plots: True
 
 # Prediction settings:
@@ -56,9 +56,9 @@ save_conf: False
 save_crop: False
 hide_labels: False  # hide labels
 hide_conf: False
-vid_stride: 1 # video frame-rate stride
+vid_stride: 1  # video frame-rate stride
 line_thickness: 3  # bounding box thickness (pixels)
-update: False # Update all models
+update: False  # Update all models
 visualize: False
 augment: False
 agnostic_nms: False  # class-agnostic NMS
@@ -77,7 +77,7 @@ cls: 0.5  # cls loss gain (scale with pixels)
 dfl: 1.5  # dfl loss gain
 fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5)
 label_smoothing: 0.0
-nbs: 64 # nominal batch size
+nbs: 64  # nominal batch size
 hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
 hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
 hsv_v: 0.4  # image HSV-Value augmentation (fraction)
@@ -92,6 +92,9 @@ mosaic: 1.0  # image mosaic (probability)
 mixup: 0.0  # image mixup (probability)
 copy_paste: 0.0  # segment copy-paste (probability)
 
+# For debugging. Don't change
+v5loader: True
+
 # Hydra configs --------------------------------------------------------------------------------------------------------
 hydra:
   output_subdir: null  # disable hydra directory creation

@@ -47,7 +47,7 @@ def generate_ddp_command(world_size, trainer):
     if using_cli:
         file_name = generate_ddp_file(trainer)
     return [
-        sys.executable, "-m", "torch.distributed.launch", "--nproc_per_node", f"{world_size}", "--master_port",
+        sys.executable, "-m", "torch.distributed.run", "--nproc_per_node", f"{world_size}", "--master_port",
         f"{find_free_network_port()}", file_name] + sys.argv[1:]

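For reference, torch.distributed.run is the maintained successor of the deprecated torch.distributed.launch module. With a world size of 2, a free port of 47123 and a generated temp file (all illustrative values), the returned argument list is roughly:

    [sys.executable, '-m', 'torch.distributed.run', '--nproc_per_node', '2',
     '--master_port', '47123', '/tmp/12345.py'] + sys.argv[1:]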

@@ -55,7 +55,7 @@ def ddp_cleanup(command, trainer):
     # delete temp file if created
     # TODO: this is a temp solution in case the file is deleted before DDP launching
     time.sleep(5)
-    tempfile_suffix = str(id(trainer)) + ".py"
+    tempfile_suffix = f"{id(trainer)}.py"
     if tempfile_suffix in "".join(command):
         for chunk in command:
             if tempfile_suffix in chunk:
@@ -4,7 +4,9 @@ import torch.nn as nn
 
 from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_dataloader
+from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
+from ultralytics.yolo.utils import colorstr
 from ultralytics.yolo.utils.loss import BboxLoss
 from ultralytics.yolo.utils.metrics import smooth_BCE
 from ultralytics.yolo.utils.modeling.tasks import DetectionModel
@@ -21,7 +23,22 @@ class DetectionTrainer(BaseTrainer):
         # TODO: manage splits differently
         # calculate stride - check if model is initialized
         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
-        return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0]
+        return create_dataloader(path=dataset_path,
+                                 imgsz=self.args.imgsz,
+                                 batch_size=batch_size,
+                                 stride=gs,
+                                 hyp=dict(self.args),
+                                 augment=mode == "train",
+                                 cache=self.args.cache,
+                                 pad=0 if mode == "train" else 0.5,
+                                 rect=self.args.rect,
+                                 rank=rank,
+                                 workers=self.args.workers,
+                                 close_mosaic=self.args.close_mosaic != 0,
+                                 prefix=colorstr(f'{mode}: '),
+                                 shuffle=mode == "train",
+                                 seed=self.args.seed)[0] if self.args.v5loader else \
+            build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0]
 
     def preprocess_batch(self, batch):
         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
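
The v5loader: True flag added to default.yaml (above) is what routes this branch: get_dataloader() delegates to the vendored YOLOv5 create_dataloader(), while v5loader: False keeps the new build_dataloader() path. A hedged usage sketch, with trainer construction elided and placeholder arguments:

    loader = trainer.get_dataloader(dataset_path='coco128/images/train2017', batch_size=16, rank=-1, mode='train')
    batch = next(iter(loader))  # dict with an 'img' tensor, consumed by preprocess_batch() shown above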
							
								
								
									
ultralytics/yolo/v8/models/yolov8m.yaml: 42 lines added (new file)
@@ -0,0 +1,42 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple

# YOLOv8.0m backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 3, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C2f, [128, True]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C2f, [256, True]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 6, C2f, [512, True]],
   [-1, 1, Conv, [768, 3, 2]],  # 7-P5/32
   [-1, 3, C2f, [768, True]],
   [-1, 1, SPPF, [768, 5]],  # 9
  ]

# YOLOv8.0m head
head:
  [[-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C2f, [512]],  # 12

   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C2f, [256]],  # 15 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 12], 1, Concat, [1]],  # cat head P4
   [-1, 3, C2f, [512]],  # 18 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 9], 1, Concat, [1]],  # cat head P5
   [-1, 3, C2f, [768]],  # 21 (P5/32-large)

   [[15, 18, 21], 1, Detect, [nc]],  # Detect(P3, P4, P5)
  ]
							
								
								
									
ultralytics/yolo/v8/models/yolov8s.yaml: 42 lines added (new file)
@@ -0,0 +1,42 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple

# YOLOv8.0s backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 3, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C2f, [128, True]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C2f, [256, True]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 6, C2f, [512, True]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C2f, [1024, True]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv8.0s head
head:
  [[-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C2f, [512]],  # 12

   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C2f, [256]],  # 15 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 12], 1, Concat, [1]],  # cat head P4
   [-1, 3, C2f, [512]],  # 18 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 9], 1, Concat, [1]],  # cat head P5
   [-1, 3, C2f, [1024]],  # 23 (P5/32-large)

   [[15, 18, 21], 1, Detect, [nc]],  # Detect(P3, P4, P5)
  ]
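
The two model yamls differ only in depth_multiple, width_multiple and the P5 channel width. Assuming the YOLOv5-style parse_model scaling rules carry over unchanged (an editor's assumption; the parser itself is not shown in this diff), repeat counts and channel widths scale roughly as in this sketch:

import math

def scaled_depth(n, depth_multiple):
    # number of module repeats after depth scaling (YOLOv5 parse_model convention, assumed here)
    return max(round(n * depth_multiple), 1) if n > 1 else n

def scaled_width(c, width_multiple, divisor=8):
    # channel count after width scaling, rounded up to a multiple of 8
    return math.ceil(c * width_multiple / divisor) * divisor

print(scaled_depth(6, 0.33), scaled_width(1024, 0.50))  # yolov8s: 2 repeats, 512 channels
print(scaled_depth(6, 0.67), scaled_width(768, 0.75))   # yolov8m: 4 repeats, 576 channels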