YOLOv5 updates (#90)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher
2022-12-25 14:33:18 +01:00
committed by GitHub
parent ebd3cfb2fd
commit 98815d560f
27 changed files with 281 additions and 161 deletions

View File

@@ -82,7 +82,7 @@ class BaseMixTransform:
indexes = [indexes]
# get image information to be used for Mosaic or MixUp
mix_labels = [deepcopy(dataset.get_label_info(index)) for index in indexes]
mix_labels = [dataset.get_label_info(index) for index in indexes]
if self.pre_transform is not None:
for i, data in enumerate(mix_labels):
@@ -134,9 +134,8 @@ class Mosaic(BaseMixTransform):
assert len(labels.get("mix_labels", [])) > 0, "There are no other images for mosaic augment."
s = self.imgsz
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
mix_labels = labels["mix_labels"]
for i in range(4):
labels_patch = deepcopy(labels) if i == 0 else deepcopy(mix_labels[i - 1])
labels_patch = (labels if i == 0 else labels["mix_labels"][i - 1]).copy()
# Load image
img = labels_patch["img"]
h, w = labels_patch["resized_shape"]
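For context on the mosaic center drawn above: with the usual YOLO convention of a negative half-size border (an assumption here; the actual value comes from self.border), the center (xc, yc) always lands in the middle half of the 2s x 2s mosaic canvas. A minimal sketch:

import random

# Minimal sketch of the mosaic-center draw, assuming the common YOLO
# convention border = (-imgsz // 2, -imgsz // 2); illustrative only.
s = 640                       # self.imgsz
border = (-s // 2, -s // 2)   # assumed value of self.border
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in border)
# with x = -320 this is uniform(320, 960): the center stays in the middle
# half of the 1280 x 1280 canvas, so all four patches remain visible
print(yc, xc)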
@@ -186,9 +185,8 @@ class Mosaic(BaseMixTransform):
"ori_shape": mosaic_labels[0]["ori_shape"],
"resized_shape": (self.imgsz * 2, self.imgsz * 2),
"im_file": mosaic_labels[0]["im_file"],
"cls": np.concatenate(cls, 0)}
final_labels["instances"] = Instances.concatenate(instances, axis=0)
"cls": np.concatenate(cls, 0),
"instances": Instances.concatenate(instances, axis=0)}
final_labels["instances"].clip(self.imgsz * 2, self.imgsz * 2)
return final_labels
@@ -345,7 +343,6 @@ class RandomPerspective:
Apply affine transforms to images and targets.
Args:
img(ndarray): image.
labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
"""
img = labels["img"]
@@ -387,7 +384,7 @@ class RandomPerspective:
return labels
def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
# Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
# Compute box candidates: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
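For reference, a worked example of the thresholds described in the comment above. The return expression is not shown in this hunk, so the criterion below (both sides above wh_thr, area ratio above area_thr, aspect ratio below ar_thr) is stated as an assumption in the spirit of the standard YOLOv5 filter:

import numpy as np

# Illustrative candidate check; the exact return line is assumed, not copied
# from the diff.
box1 = np.array([[0.], [0.], [100.], [50.]])  # before augment, shape (4, n)
box2 = np.array([[0.], [0.], [40.], [10.]])   # after augment, shape (4, n)
wh_thr, ar_thr, area_thr, eps = 2, 100, 0.1, 1e-16
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]  # 100, 50
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]  # 40, 10
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # 4.0
keep = (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)
print(keep)  # [False]: 40*10 / (100*50) = 0.08 < 0.1, so the box is rejected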
@@ -609,6 +606,7 @@ class Format:
self.batch_idx = batch_idx # keep the batch indexes
def __call__(self, labels):
labels.pop("dataset", None)
img = labels["img"]
h, w = img.shape[:2]
cls = labels.pop("cls")
@@ -672,10 +670,7 @@ def mosaic_transforms(imgsz, hyp):
),])
return Compose([
pre_transform,
MixUp(
pre_transform=pre_transform,
p=hyp.mixup,
),
MixUp(pre_transform=pre_transform, p=hyp.mixup),
Albumentations(p=1.0),
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
RandomFlip(direction="vertical", p=hyp.flipud),
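The pipelines above, like build_transforms later in this diff, rely on Compose applying each transform in order to the labels dict. A minimal sketch of that pattern, assuming the interface implied by the calls here (not the actual ultralytics Compose implementation):

class ComposeSketch:
    # Illustrative stand-in for a Compose-style pipeline: each transform
    # receives the labels dict and returns it (possibly modified).
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, labels):
        for t in self.transforms:
            labels = t(labels)
        return labels

    def append(self, transform):
        self.transforms.append(transform)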

View File

@@ -1,4 +1,5 @@
import glob
import math
import os
from multiprocessing.pool import ThreadPool
from pathlib import Path
@@ -121,7 +122,7 @@ class BaseDataset(Dataset):
r = self.imgsz / max(h0, w0) # ratio
if r != 1: # if sizes are not equal
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
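On the int to math.ceil change above: for shapes that do not divide evenly, int() truncates the short side down while math.ceil() rounds it up, and rounding up also keeps the long side from landing a pixel under imgsz when floating-point error nudges w0 * r just below an integer. A small illustration with assumed values (not from the diff):

import math

h0, w0, imgsz = 700, 1200, 640
r = imgsz / max(h0, w0)                      # ~0.533
# short side: 700 * r ~= 373.33 -> int() gives 373, math.ceil() gives 374
print(int(h0 * r), math.ceil(h0 * r))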
@@ -179,10 +180,7 @@ class BaseDataset(Dataset):
def get_label_info(self, index):
label = self.labels[index].copy()
img, (h0, w0), (h, w) = self.load_image(index)
label["img"] = img
label["ori_shape"] = (h0, w0)
label["resized_shape"] = (h, w)
label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
if self.rect:
label["rect_shape"] = self.batch_shapes[self.batch[index]]
label = self.update_labels_info(label)

View File

@@ -64,7 +64,7 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
label_path=label_path,
imgsz=cfg.imgsz,
batch_size=batch_size,
augment=True if mode == "train" else False, # augmentation
augment=mode == "train", # augmentation
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
rect=cfg.rect if mode == "train" else True, # rectangular batches
cache=None if cfg.noval else cfg.get("cache", None),
@@ -73,31 +73,25 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
pad=0.0 if mode == "train" else 0.5,
prefix=colorstr(f"{mode}: "),
use_segments=cfg.task == "segment",
use_keypoints=cfg.task == "keypoint",
)
use_keypoints=cfg.task == "keypoint")
batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
workers = cfg.workers if mode == "train" else cfg.workers * 2
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if cfg.image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader # allow attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
return (
loader(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle and sampler is None,
num_workers=nw,
sampler=sampler,
pin_memory=PIN_MEMORY,
collate_fn=getattr(dataset, "collate_fn", None),
worker_init_fn=seed_worker,
generator=generator,
),
dataset,
)
return loader(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle and sampler is None,
num_workers=nw,
sampler=sampler,
pin_memory=PIN_MEMORY,
collate_fn=getattr(dataset, "collate_fn", None),
worker_init_fn=seed_worker,
generator=generator), dataset
# build classification
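A worked example of the worker-count formula in the hunk above, with illustrative numbers (not from the diff):

# 16 CPUs, 2 CUDA devices, batch size 64, cfg.workers = 8
cpu_count, nd, batch_size, workers = 16, 2, 64, 8
nw = min([cpu_count // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
print(nw)  # min(8, 64, 8) -> 8 dataloader workers for this process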

View File

@@ -124,13 +124,9 @@ class YOLODataset(BaseDataset):
# TODO: use hyp config to set all these augmentations
def build_transforms(self, hyp=None):
mosaic = self.augment and not self.rect
# mosaic = False
if self.augment:
if mosaic:
transforms = mosaic_transforms(self.imgsz, hyp)
else:
transforms = affine_transforms(self.imgsz, hyp)
mosaic = self.augment and not self.rect
transforms = mosaic_transforms(self.imgsz, hyp) if mosaic else affine_transforms(self.imgsz, hyp)
else:
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz))])
transforms.append(
@@ -143,7 +139,7 @@ class YOLODataset(BaseDataset):
def update_labels_info(self, label):
"""custom your label format here"""
# NOTE: cls is not with bboxes now, since other tasks like classification and semantic segmentation need a independent cls label
# NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
# we can also make it support classification and semantic segmentation by adding or removing some dict keys here.
bboxes = label.pop("bboxes")
segments = label.pop("segments")
@@ -206,7 +202,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
else:
sample = self.torch_transforms(im)
return OrderedDict(img=sample, cls=j)
return {'img': sample, 'cls': j}
def __len__(self) -> int:
return len(self.samples)

View File

@@ -0,0 +1,113 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# COCO 2017 dataset http://cocodataset.org by Microsoft
# Example usage: python train.py --data coco.yaml
# parent
# ├── yolov5
# └── datasets
# └── coco ← downloads here (20.1 GB)
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: ../datasets/coco # dataset root dir
train: train2017.txt # train images (relative to 'path') 118287 images
val: val2017.txt # val images (relative to 'path') 5000 images
test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
# Classes
names:
0: person
1: bicycle
2: car
3: motorcycle
4: airplane
5: bus
6: train
7: truck
8: boat
9: traffic light
10: fire hydrant
11: stop sign
12: parking meter
13: bench
14: bird
15: cat
16: dog
17: horse
18: sheep
19: cow
20: elephant
21: bear
22: zebra
23: giraffe
24: backpack
25: umbrella
26: handbag
27: tie
28: suitcase
29: frisbee
30: skis
31: snowboard
32: sports ball
33: kite
34: baseball bat
35: baseball glove
36: skateboard
37: surfboard
38: tennis racket
39: bottle
40: wine glass
41: cup
42: fork
43: knife
44: spoon
45: bowl
46: banana
47: apple
48: sandwich
49: orange
50: broccoli
51: carrot
52: hot dog
53: pizza
54: donut
55: cake
56: chair
57: couch
58: potted plant
59: bed
60: dining table
61: toilet
62: tv
63: laptop
64: mouse
65: remote
66: keyboard
67: cell phone
68: microwave
69: oven
70: toaster
71: sink
72: refrigerator
73: book
74: clock
75: vase
76: scissors
77: teddy bear
78: hair drier
79: toothbrush
# Download script/URL (optional)
download: |
from utils.general import download, Path
# Download labels
segments = True # segment or box labels
dir = Path(yaml['path']) # dataset root dir
url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
urls = [url + ('coco2017labels-segments.zip' if segments else 'coco2017labels.zip')] # labels
download(urls, dir=dir.parent)
# Download data
urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
download(urls, dir=dir / 'images', threads=3)
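A minimal sketch of consuming this config (assumed reader-side code, not part of the diff): the path/train/val/test entries and the names mapping read directly with PyYAML, while the download entry is a Python snippet that YOLOv5-style loaders typically exec() with the parsed dict bound to the name yaml when the dataset is missing.

import yaml  # PyYAML

with open("coco.yaml", errors="ignore") as f:
    data = yaml.safe_load(f)
print(data["path"], data["train"])           # ../datasets/coco train2017.txt
print(len(data["names"]), data["names"][0])  # 80 person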