Import YOLOv5 dataloader (#94)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent ae05d44877
commit 16e3c08883
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -20,21 +20,21 @@ jobs:
matrix: matrix:
os: [ ubuntu-latest ] os: [ ubuntu-latest ]
python-version: [ '3.10' ] python-version: [ '3.10' ]
model: [ yolov5n ] model: [ yolov8n ]
torch: [ latest ] torch: [ latest ]
# include: # include:
# - os: ubuntu-latest # - os: ubuntu-latest
# python-version: '3.7' # '3.6.8' min # python-version: '3.7' # '3.6.8' min
# model: yolov5n # model: yolov8n
# - os: ubuntu-latest # - os: ubuntu-latest
# python-version: '3.8' # python-version: '3.8'
# model: yolov5n # model: yolov8n
# - os: ubuntu-latest # - os: ubuntu-latest
# python-version: '3.9' # python-version: '3.9'
# model: yolov5n # model: yolov8n
# - os: ubuntu-latest # - os: ubuntu-latest
# python-version: '3.8' # torch 1.7.0 requires python >=3.6, <=3.8 # python-version: '3.8' # torch 1.7.0 requires python >=3.6, <=3.8
# model: yolov5n # model: yolov8n
# torch: '1.7.0' # min torch version CI https://pypi.org/project/torchvision/ # torch: '1.7.0' # min torch version CI https://pypi.org/project/torchvision/
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3

@ -0,0 +1,402 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Image augmentation functions
"""
import math
import random
import cv2
import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from ultralytics.yolo.utils import LOGGER, colorstr
from ultralytics.yolo.utils.checks import check_version
from ultralytics.yolo.utils.metrics import bbox_ioa
from ultralytics.yolo.utils.ops import resample_segments, segment2box, xywhn2xyxy
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
class Albumentations:
# YOLOv5 Albumentations class (optional, only used if package is installed)
def __init__(self, size=640):
self.transform = None
prefix = colorstr('albumentations: ')
try:
import albumentations as A
check_version(A.__version__, '1.0.3', hard=True) # version requirement
T = [
A.RandomResizedCrop(height=size, width=size, scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0),
A.Blur(p=0.01),
A.MedianBlur(p=0.01),
A.ToGray(p=0.01),
A.CLAHE(p=0.01),
A.RandomBrightnessContrast(p=0.0),
A.RandomGamma(p=0.0),
A.ImageCompression(quality_lower=75, p=0.0)] # transforms
self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
except ImportError: # package not installed, skip
pass
except Exception as e:
LOGGER.info(f'{prefix}{e}')
def __call__(self, im, labels, p=1.0):
if self.transform and random.random() < p:
new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed
im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
return im, labels
def normalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=False):
# Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = (x - mean) / std
return TF.normalize(x, mean, std, inplace=inplace)
def denormalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
# Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = x * std + mean
for i in range(3):
x[:, i] = x[:, i] * std[i] + mean[i]
return x
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
# HSV color-space augmentation
if hgain or sgain or vgain:
r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
dtype = im.dtype # uint8
x = np.arange(0, 256, dtype=r.dtype)
lut_hue = ((x * r[0]) % 180).astype(dtype)
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im) # no return needed
def hist_equalize(im, clahe=True, bgr=False):
# Equalize histogram on BGR image 'im' with im.shape(n,m,3) and range 0-255
yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
if clahe:
c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
yuv[:, :, 0] = c.apply(yuv[:, :, 0])
else:
yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0]) # equalize Y channel histogram
return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB) # convert YUV image to RGB
def replicate(im, labels):
# Replicate labels
h, w = im.shape[:2]
boxes = labels[:, 1:].astype(int)
x1, y1, x2, y2 = boxes.T
s = ((x2 - x1) + (y2 - y1)) / 2 # side length (pixels)
for i in s.argsort()[:round(s.size * 0.5)]: # smallest indices
x1b, y1b, x2b, y2b = boxes[i]
bh, bw = y2b - y1b, x2b - x1b
yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw)) # offset x, y
x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b] # im4[ymin:ymax, xmin:xmax]
labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)
return im, labels
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
# Resize and pad image while meeting stride-multiple constraints
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
if not scaleup: # only scale down, do not scale up (for better val mAP)
r = min(r, 1.0)
# Compute padding
ratio = r, r # width, height ratios
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
if auto: # minimum rectangle
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
elif scaleFill: # stretch
dw, dh = 0.0, 0.0
new_unpad = (new_shape[1], new_shape[0])
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
dw /= 2 # divide padding into 2 sides
dh /= 2
if shape[::-1] != new_unpad: # resize
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
return im, ratio, (dw, dh)
def random_perspective(im,
targets=(),
segments=(),
degrees=10,
translate=.1,
scale=.1,
shear=10,
perspective=0.0,
border=(0, 0)):
# torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10))
# targets = [cls, xyxy]
height = im.shape[0] + border[0] * 2 # shape(h,w,c)
width = im.shape[1] + border[1] * 2
# Center
C = np.eye(3)
C[0, 2] = -im.shape[1] / 2 # x translation (pixels)
C[1, 2] = -im.shape[0] / 2 # y translation (pixels)
# Perspective
P = np.eye(3)
P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y)
P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x)
# Rotation and Scale
R = np.eye(3)
a = random.uniform(-degrees, degrees)
# a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
s = random.uniform(1 - scale, 1 + scale)
# s = 2 ** random.uniform(-scale, scale)
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
# Shear
S = np.eye(3)
S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg)
S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg)
# Translation
T = np.eye(3)
T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width # x translation (pixels)
T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height # y translation (pixels)
# Combined rotation matrix
M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
if perspective:
im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114))
else: # affine
im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
# Visualize
# import matplotlib.pyplot as plt
# ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
# ax[0].imshow(im[:, :, ::-1]) # base
# ax[1].imshow(im2[:, :, ::-1]) # warped
# Transform label coordinates
n = len(targets)
if n:
use_segments = any(x.any() for x in segments)
new = np.zeros((n, 4))
if use_segments: # warp segments
segments = resample_segments(segments) # upsample
for i, segment in enumerate(segments):
xy = np.ones((len(segment), 3))
xy[:, :2] = segment
xy = xy @ M.T # transform
xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2] # perspective rescale or affine
# clip
new[i] = segment2box(xy, width, height)
else: # warp boxes
xy = np.ones((n * 4, 3))
xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
xy = xy @ M.T # transform
xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
# create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
# clip
new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)
# filter candidates
i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
targets = targets[i]
targets[:, 1:5] = new[i]
return im, targets
def copy_paste(im, labels, segments, p=0.5):
# Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
n = len(segments)
if p and n:
h, w, c = im.shape # height, width, channels
im_new = np.zeros(im.shape, np.uint8)
# calculate ioa first then select indexes randomly
boxes = np.stack([w - labels[:, 3], labels[:, 2], w - labels[:, 1], labels[:, 4]], axis=-1) # (n, 4)
ioa = bbox_ioa(boxes, labels[:, 1:5]) # intersection over area
indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )
n = len(indexes)
for j in random.sample(list(indexes), k=round(p * n)):
l, box, s = labels[j], boxes[j], segments[j]
labels = np.concatenate((labels, [[l[0], *box]]), 0)
segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (1, 1, 1), cv2.FILLED)
result = cv2.flip(im, 1) # augment segments (flip left-right)
i = cv2.flip(im_new, 1).astype(bool)
im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
return im, labels, segments
def cutout(im, labels, p=0.5):
# Applies image cutout augmentation https://arxiv.org/abs/1708.04552
if random.random() < p:
h, w = im.shape[:2]
scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
for s in scales:
mask_h = random.randint(1, int(h * s)) # create random masks
mask_w = random.randint(1, int(w * s))
# box
xmin = max(0, random.randint(0, w) - mask_w // 2)
ymin = max(0, random.randint(0, h) - mask_h // 2)
xmax = min(w, xmin + mask_w)
ymax = min(h, ymin + mask_h)
# apply random color mask
im[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
# return unobscured labels
if len(labels) and s > 0.03:
box = np.array([[xmin, ymin, xmax, ymax]], dtype=np.float32)
ioa = bbox_ioa(box, xywhn2xyxy(labels[:, 1:5], w, h))[0] # intersection over area
labels = labels[ioa < 0.60] # remove >60% obscured labels
return labels
def mixup(im, labels, im2, labels2):
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
im = (im * r + im2 * (1 - r)).astype(np.uint8)
labels = np.concatenate((labels, labels2), 0)
return im, labels
def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
# Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
def classify_albumentations(
augment=True,
size=224,
scale=(0.08, 1.0),
ratio=(0.75, 1.0 / 0.75), # 0.75, 1.33
hflip=0.5,
vflip=0.0,
jitter=0.4,
mean=IMAGENET_MEAN,
std=IMAGENET_STD,
auto_aug=False):
# YOLOv5 classification Albumentations (optional, only used if package is installed)
prefix = colorstr('albumentations: ')
try:
import albumentations as A
from albumentations.pytorch import ToTensorV2
check_version(A.__version__, '1.0.3', hard=True) # version requirement
if augment: # Resize and crop
T = [A.RandomResizedCrop(height=size, width=size, scale=scale, ratio=ratio)]
if auto_aug:
# TODO: implement AugMix, AutoAug & RandAug in albumentation
LOGGER.info(f'{prefix}auto augmentations are currently not supported')
else:
if hflip > 0:
T += [A.HorizontalFlip(p=hflip)]
if vflip > 0:
T += [A.VerticalFlip(p=vflip)]
if jitter > 0:
color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, satuaration, 0 hue
T += [A.ColorJitter(*color_jitter, 0)]
else: # Use fixed crop for eval set (reproducibility)
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
return A.Compose(T)
except ImportError: # package not installed, skip
LOGGER.warning(f'{prefix}⚠️ not found, install with `pip install albumentations` (recommended)')
except Exception as e:
LOGGER.info(f'{prefix}{e}')
def classify_transforms(size=224):
# Transforms to apply if albumentations not installed
assert isinstance(size, int), f'ERROR: classify_transforms size {size} must be integer, not (list, tuple)'
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
class LetterBox:
# YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
def __init__(self, size=(640, 640), auto=False, stride=32):
super().__init__()
self.h, self.w = (size, size) if isinstance(size, int) else size
self.auto = auto # pass max size integer, automatically solve for short side using stride
self.stride = stride # used with auto
def __call__(self, im): # im = np.array HWC
imh, imw = im.shape[:2]
r = min(self.h / imh, self.w / imw) # ratio of new/old
h, w = round(imh * r), round(imw * r) # resized image
hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w
top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
return im_out
class CenterCrop:
# YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
def __init__(self, size=640):
super().__init__()
self.h, self.w = (size, size) if isinstance(size, int) else size
def __call__(self, im): # im = np.array HWC
imh, imw = im.shape[:2]
m = min(imh, imw) # min dimension
top, left = (imh - m) // 2, (imw - m) // 2
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
class ToTensor:
# YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
def __init__(self, half=False):
super().__init__()
self.half = half
def __call__(self, im): # im = np.array HWC in BGR order
im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
im = torch.from_numpy(im) # to torch
im = im.half() if self.half else im.float() # uint8 to fp16/32
im /= 255.0 # 0-255 to 0.0-1.0
return im

File diff suppressed because it is too large Load Diff

@ -1,7 +1,7 @@
import torch import torch
import yaml import yaml
# from ultralytics import yolo from ultralytics import yolo # (required for python usage)
# from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml # from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils import LOGGER from ultralytics.yolo.utils import LOGGER

@ -110,8 +110,9 @@ class BaseTrainer:
world_size = torch.cuda.device_count() world_size = torch.cuda.device_count()
if world_size > 1 and "LOCAL_RANK" not in os.environ: if world_size > 1 and "LOCAL_RANK" not in os.environ:
command = generate_ddp_command(world_size, self) command = generate_ddp_command(world_size, self)
print('DDP command: ', command)
subprocess.Popen(command) subprocess.Popen(command)
ddp_cleanup(command, self) # ddp_cleanup(command, self) # TODO: uncomment and fix
else: else:
self._do_train(int(os.getenv("RANK", -1)), world_size) self._do_train(int(os.getenv("RANK", -1)), world_size)
@ -121,7 +122,7 @@ class BaseTrainer:
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
self.device = torch.device('cuda', rank) self.device = torch.device('cuda', rank)
self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ") self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
mp.use_start_method('spawn', force=True) mp.set_start_method('spawn', force=True)
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size) dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
def _setup_train(self, rank, world_size): def _setup_train(self, rank, world_size):
@ -195,6 +196,11 @@ class BaseTrainer:
for i, batch in pbar: for i, batch in pbar:
self.trigger_callbacks("on_train_batch_start") self.trigger_callbacks("on_train_batch_start")
# update dataloader attributes (optional)
if epoch == (self.epochs - self.args.close_mosaic) and hasattr(self.train_loader.dataset, 'mosaic'):
LOGGER.info("Closing dataloader mosaic")
self.train_loader.dataset.mosaic = False
# warmup # warmup
ni = i + nb * epoch ni = i + nb * epoch
if ni <= nw: if ni <= nw:
@ -207,7 +213,7 @@ class BaseTrainer:
if 'momentum' in x: if 'momentum' in x:
x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum]) x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
# Forward # forward
with torch.cuda.amp.autocast(self.amp): with torch.cuda.amp.autocast(self.amp):
batch = self.preprocess_batch(batch) batch = self.preprocess_batch(batch)
preds = self.model(batch["img"]) preds = self.model(batch["img"])
@ -217,10 +223,10 @@ class BaseTrainer:
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \ self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
else self.loss_items else self.loss_items
# Backward # backward
self.scaler.scale(self.loss).backward() self.scaler.scale(self.loss).backward()
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html # optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= self.accumulate: if ni - last_opt_step >= self.accumulate:
self.optimizer_step() self.optimizer_step()
last_opt_step = ni last_opt_step = ni

@ -6,13 +6,15 @@ import sys
import threading import threading
from pathlib import Path from pathlib import Path
import cv2
import IPython import IPython
import pandas as pd
# Constants # Constants
FILE = Path(__file__).resolve() FILE = Path(__file__).resolve()
ROOT = FILE.parents[2] # YOLO ROOT = FILE.parents[2] # YOLO
RANK = int(os.getenv('RANK', -1)) RANK = int(os.getenv('RANK', -1))
DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets')) # global datasets directory
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
@ -20,6 +22,14 @@ VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global ver
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format
LOGGING_NAME = 'yolov5' LOGGING_NAME = 'yolov5'
# Settings
# torch.set_printoptions(linewidth=320, precision=5, profile='long')
# np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
def is_colab(): def is_colab():
# Is environment a Google Colab instance? # Is environment a Google Colab instance?

@ -8,7 +8,7 @@ mode: "train" # choice=['train', 'val', 'infer']
# Train settings ------------------------------------------------------------------------------------------------------- # Train settings -------------------------------------------------------------------------------------------------------
model: null # i.e. yolov5s.pt, yolo.yaml model: null # i.e. yolov5s.pt, yolo.yaml
data: null # i.e. coco128.yaml data: null # i.e. coco128.yaml
epochs: 300 epochs: 100
batch_size: 16 batch_size: 16
imgsz: 640 imgsz: 640
nosave: False nosave: False
@ -42,7 +42,7 @@ noval: False
save_json: False save_json: False
save_hybrid: False save_hybrid: False
conf_thres: 0.001 conf_thres: 0.001
iou_thres: 0.6 iou_thres: 0.7
max_det: 300 max_det: 300
half: True half: True
dnn: False # use OpenCV DNN for ONNX inference dnn: False # use OpenCV DNN for ONNX inference
@ -92,6 +92,9 @@ mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability) mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability) copy_paste: 0.0 # segment copy-paste (probability)
# For debugging. Don't change
v5loader: True
# Hydra configs -------------------------------------------------------------------------------------------------------- # Hydra configs --------------------------------------------------------------------------------------------------------
hydra: hydra:
output_subdir: null # disable hydra directory creation output_subdir: null # disable hydra directory creation

@ -47,7 +47,7 @@ def generate_ddp_command(world_size, trainer):
if using_cli: if using_cli:
file_name = generate_ddp_file(trainer) file_name = generate_ddp_file(trainer)
return [ return [
sys.executable, "-m", "torch.distributed.launch", "--nproc_per_node", f"{world_size}", "--master_port", sys.executable, "-m", "torch.distributed.run", "--nproc_per_node", f"{world_size}", "--master_port",
f"{find_free_network_port()}", file_name] + sys.argv[1:] f"{find_free_network_port()}", file_name] + sys.argv[1:]
@ -55,7 +55,7 @@ def ddp_cleanup(command, trainer):
# delete temp file if created # delete temp file if created
# TODO: this is a temp solution in case the file is deleted before DDP launching # TODO: this is a temp solution in case the file is deleted before DDP launching
time.sleep(5) time.sleep(5)
tempfile_suffix = str(id(trainer)) + ".py" tempfile_suffix = f"{id(trainer)}.py"
if tempfile_suffix in "".join(command): if tempfile_suffix in "".join(command):
for chunk in command: for chunk in command:
if tempfile_suffix in chunk: if tempfile_suffix in chunk:

@ -4,7 +4,9 @@ import torch.nn as nn
from ultralytics.yolo import v8 from ultralytics.yolo import v8
from ultralytics.yolo.data import build_dataloader from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
from ultralytics.yolo.utils import colorstr
from ultralytics.yolo.utils.loss import BboxLoss from ultralytics.yolo.utils.loss import BboxLoss
from ultralytics.yolo.utils.metrics import smooth_BCE from ultralytics.yolo.utils.metrics import smooth_BCE
from ultralytics.yolo.utils.modeling.tasks import DetectionModel from ultralytics.yolo.utils.modeling.tasks import DetectionModel
@ -21,7 +23,22 @@ class DetectionTrainer(BaseTrainer):
# TODO: manage splits differently # TODO: manage splits differently
# calculate stride - check if model is initialized # calculate stride - check if model is initialized
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0] return create_dataloader(path=dataset_path,
imgsz=self.args.imgsz,
batch_size=batch_size,
stride=gs,
hyp=dict(self.args),
augment=mode == "train",
cache=self.args.cache,
pad=0 if mode == "train" else 0.5,
rect=self.args.rect,
rank=rank,
workers=self.args.workers,
close_mosaic=self.args.close_mosaic != 0,
prefix=colorstr(f'{mode}: '),
shuffle=mode == "train",
seed=self.args.seed)[0] if self.args.v5loader else \
build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0]
def preprocess_batch(self, batch): def preprocess_batch(self, batch):
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255

@ -0,0 +1,42 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.67 # model depth multiple
width_multiple: 0.75 # layer channel multiple
# YOLOv8.0m backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 3, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C2f, [128, True]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C2f, [256, True]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 6, C2f, [512, True]],
[-1, 1, Conv, [768, 3, 2]], # 7-P5/32
[-1, 3, C2f, [768, True]],
[-1, 1, SPPF, [768, 5]], # 9
]
# YOLOv8.0m head
head:
[[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C2f, [512]], # 13
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C2f, [256]], # 17 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 12], 1, Concat, [1]], # cat head P4
[-1, 3, C2f, [512]], # 20 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 9], 1, Concat, [1]], # cat head P5
[-1, 3, C2f, [768]], # 23 (P5/32-large)
[[15, 18, 21], 1, Detect, [nc]], # Detect(P3, P4, P5)
]

@ -0,0 +1,42 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
# YOLOv8.0s backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 3, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C2f, [128, True]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C2f, [256, True]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 6, C2f, [512, True]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C2f, [1024, True]],
[-1, 1, SPPF, [1024, 5]], # 9
]
# YOLOv8.0s head
head:
[[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C2f, [512]], # 13
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C2f, [256]], # 17 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 12], 1, Concat, [1]], # cat head P4
[-1, 3, C2f, [512]], # 20 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 9], 1, Concat, [1]], # cat head P5
[-1, 3, C2f, [1024]], # 23 (P5/32-large)
[[15, 18, 21], 1, Detect, [nc]], # Detect(P3, P4, P5)
]
Loading…
Cancel
Save