ultralytics 8.0.65
YOLOv8 Pose models (#1347)
Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mert Can Demir <validatedev@gmail.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Fabian Greavu <fabiangreavu@gmail.com> Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com> Co-authored-by: Eric Pedley <ericpedley@gmail.com> Co-authored-by: JustasBart <40023722+JustasBart@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Aarni Koskela <akx@iki.fi> Co-authored-by: Sergio Sanchez <sergio.ssm.97@gmail.com> Co-authored-by: Bogdan Gheorghe <112427971+bogdan-galileo@users.noreply.github.com> Co-authored-by: Jaap van de Loosdrecht <jaap@vdlmv.nl> Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com> Co-authored-by: nerdyespresso <106761627+nerdyespresso@users.noreply.github.com> Co-authored-by: Farid Inawan <frdteknikelektro@gmail.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Alexander Duda <Alexander.Duda@me.com> Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com> Co-authored-by: Snyk bot <snyk-bot@snyk.io> Co-authored-by: majid nasiri <majnasai@gmail.com>
This commit is contained in:
@ -16,6 +16,8 @@ from ..utils.metrics import bbox_ioa
|
||||
from ..utils.ops import segment2box
|
||||
from .utils import polygons2masks, polygons2masks_overlap
|
||||
|
||||
POSE_FLIPLR_INDEX = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
|
||||
|
||||
|
||||
# TODO: we might need a BaseTransform to make all these augments be compatible with both classification and semantic
|
||||
class BaseTransform:
|
||||
@ -309,27 +311,22 @@ class RandomPerspective:
|
||||
"""apply affine to keypoints.
|
||||
|
||||
Args:
|
||||
keypoints(ndarray): keypoints, [N, 17, 2].
|
||||
keypoints(ndarray): keypoints, [N, 17, 3].
|
||||
M(ndarray): affine matrix.
|
||||
Return:
|
||||
new_keypoints(ndarray): keypoints after affine, [N, 17, 2].
|
||||
new_keypoints(ndarray): keypoints after affine, [N, 17, 3].
|
||||
"""
|
||||
n = len(keypoints)
|
||||
n, nkpt = keypoints.shape[:2]
|
||||
if n == 0:
|
||||
return keypoints
|
||||
new_keypoints = np.ones((n * 17, 3))
|
||||
new_keypoints[:, :2] = keypoints.reshape(n * 17, 2) # num_kpt is hardcoded to 17
|
||||
new_keypoints = new_keypoints @ M.T # transform
|
||||
new_keypoints = (new_keypoints[:, :2] / new_keypoints[:, 2:3]).reshape(n, 34) # perspective rescale or affine
|
||||
new_keypoints[keypoints.reshape(-1, 34) == 0] = 0
|
||||
x_kpts = new_keypoints[:, list(range(0, 34, 2))]
|
||||
y_kpts = new_keypoints[:, list(range(1, 34, 2))]
|
||||
|
||||
x_kpts[np.logical_or.reduce((x_kpts < 0, x_kpts > self.size[0], y_kpts < 0, y_kpts > self.size[1]))] = 0
|
||||
y_kpts[np.logical_or.reduce((x_kpts < 0, x_kpts > self.size[0], y_kpts < 0, y_kpts > self.size[1]))] = 0
|
||||
new_keypoints[:, list(range(0, 34, 2))] = x_kpts
|
||||
new_keypoints[:, list(range(1, 34, 2))] = y_kpts
|
||||
return new_keypoints.reshape(n, 17, 2)
|
||||
xy = np.ones((n * nkpt, 3))
|
||||
visible = keypoints[..., 2].reshape(n * nkpt, 1)
|
||||
xy[:, :2] = keypoints[..., :2].reshape(n * nkpt, 2)
|
||||
xy = xy @ M.T # transform
|
||||
xy = xy[:, :2] / xy[:, 2:3] # perspective rescale or affine
|
||||
out_mask = (xy[:, 0] < 0) | (xy[:, 1] < 0) | (xy[:, 0] > self.size[0]) | (xy[:, 1] > self.size[1])
|
||||
visible[out_mask] = 0
|
||||
return np.concatenate([xy, visible], axis=-1).reshape(n, nkpt, 3)
|
||||
|
||||
def __call__(self, labels):
|
||||
"""
|
||||
@ -415,12 +412,13 @@ class RandomHSV:
|
||||
|
||||
class RandomFlip:
|
||||
|
||||
def __init__(self, p=0.5, direction='horizontal') -> None:
|
||||
def __init__(self, p=0.5, direction='horizontal', flip_idx=None) -> None:
|
||||
assert direction in ['horizontal', 'vertical'], f'Support direction `horizontal` or `vertical`, got {direction}'
|
||||
assert 0 <= p <= 1.0
|
||||
|
||||
self.p = p
|
||||
self.direction = direction
|
||||
self.flip_idx = flip_idx
|
||||
|
||||
def __call__(self, labels):
|
||||
img = labels['img']
|
||||
@ -437,6 +435,9 @@ class RandomFlip:
|
||||
if self.direction == 'horizontal' and random.random() < self.p:
|
||||
img = np.fliplr(img)
|
||||
instances.fliplr(w)
|
||||
# for keypoints
|
||||
if self.flip_idx is not None and instances.keypoints is not None:
|
||||
instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :])
|
||||
labels['img'] = np.ascontiguousarray(img)
|
||||
labels['instances'] = instances
|
||||
return labels
|
||||
@ -633,7 +634,7 @@ class Format:
|
||||
labels['cls'] = torch.from_numpy(cls) if nl else torch.zeros(nl)
|
||||
labels['bboxes'] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
|
||||
if self.return_keypoint:
|
||||
labels['keypoints'] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
|
||||
labels['keypoints'] = torch.from_numpy(instances.keypoints)
|
||||
# then we can use collate_fn
|
||||
if self.batch_idx:
|
||||
labels['batch_idx'] = torch.zeros(nl)
|
||||
@ -672,13 +673,17 @@ def v8_transforms(dataset, imgsz, hyp):
|
||||
perspective=hyp.perspective,
|
||||
pre_transform=LetterBox(new_shape=(imgsz, imgsz)),
|
||||
)])
|
||||
flip_idx = dataset.data.get('flip_idx', None) # for keypoints augmentation
|
||||
if dataset.use_keypoints and flip_idx is None and hyp.fliplr > 0.0:
|
||||
hyp.fliplr = 0.0
|
||||
LOGGER.warning("WARNING ⚠️ No `flip_idx` provided while training keypoints, setting augmentation 'fliplr=0.0'")
|
||||
return Compose([
|
||||
pre_transform,
|
||||
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
|
||||
Albumentations(p=1.0),
|
||||
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
||||
RandomFlip(direction='vertical', p=hyp.flipud),
|
||||
RandomFlip(direction='horizontal', p=hyp.fliplr)]) # transforms
|
||||
RandomFlip(direction='horizontal', p=hyp.fliplr, flip_idx=flip_idx)]) # transforms
|
||||
|
||||
|
||||
# Classification augmentations -----------------------------------------------------------------------------------------
|
||||
|
@ -61,7 +61,7 @@ def seed_worker(worker_id): # noqa
|
||||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode='train'):
|
||||
def build_dataloader(cfg, batch, img_path, data_info, stride=32, rect=False, rank=-1, mode='train'):
|
||||
assert mode in ['train', 'val']
|
||||
shuffle = mode == 'train'
|
||||
if cfg.rect and shuffle:
|
||||
@ -81,9 +81,9 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
|
||||
pad=0.0 if mode == 'train' else 0.5,
|
||||
prefix=colorstr(f'{mode}: '),
|
||||
use_segments=cfg.task == 'segment',
|
||||
use_keypoints=cfg.task == 'keypoint',
|
||||
names=names,
|
||||
classes=cfg.classes)
|
||||
use_keypoints=cfg.task == 'pose',
|
||||
classes=cfg.classes,
|
||||
data=data_info)
|
||||
|
||||
batch = min(batch, len(dataset))
|
||||
nd = torch.cuda.device_count() # number of CUDA devices
|
||||
|
@ -57,11 +57,11 @@ class YOLODataset(BaseDataset):
|
||||
single_cls=False,
|
||||
use_segments=False,
|
||||
use_keypoints=False,
|
||||
names=None,
|
||||
data=None,
|
||||
classes=None):
|
||||
self.use_segments = use_segments
|
||||
self.use_keypoints = use_keypoints
|
||||
self.names = names
|
||||
self.data = data
|
||||
assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.'
|
||||
super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls,
|
||||
classes)
|
||||
@ -77,10 +77,16 @@ class YOLODataset(BaseDataset):
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
|
||||
total = len(self.im_files)
|
||||
nc = len(self.data['names'])
|
||||
nkpt, ndim = self.data.get('kpt_shape', (0, 0))
|
||||
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
|
||||
raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
|
||||
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'")
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(func=verify_image_label,
|
||||
iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
|
||||
repeat(self.use_keypoints), repeat(len(self.names))))
|
||||
repeat(self.use_keypoints), repeat(len(self.data['names'])), repeat(nkpt),
|
||||
repeat(ndim)))
|
||||
pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
|
||||
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
||||
nm += nm_f
|
||||
|
@ -6,10 +6,10 @@ import json
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
import zipfile
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
from tarfile import is_tarfile
|
||||
from zipfile import is_zipfile
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -61,7 +61,7 @@ def exif_size(img):
|
||||
|
||||
def verify_image_label(args):
|
||||
# Verify one image-label pair
|
||||
im_file, lb_file, prefix, keypoint, num_cls = args
|
||||
im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
|
||||
# number (missing, found, empty, corrupt), message, segments, keypoints
|
||||
nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, '', [], None
|
||||
try:
|
||||
@ -92,25 +92,19 @@ def verify_image_label(args):
|
||||
nl = len(lb)
|
||||
if nl:
|
||||
if keypoint:
|
||||
assert lb.shape[1] == 56, 'labels require 56 columns each'
|
||||
assert (lb[:, 5::3] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
|
||||
assert (lb[:, 6::3] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
|
||||
kpts = np.zeros((lb.shape[0], 39))
|
||||
for i in range(len(lb)):
|
||||
kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5, 3)) # remove occlusion param from GT
|
||||
kpts[i] = np.hstack((lb[i, :5], kpt))
|
||||
lb = kpts
|
||||
assert lb.shape[1] == 39, 'labels require 39 columns each after removing occlusion parameter'
|
||||
assert lb.shape[1] == (5 + nkpt * ndim), f'labels require {(5 + nkpt * ndim)} columns each'
|
||||
assert (lb[:, 5::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
|
||||
assert (lb[:, 6::ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
|
||||
else:
|
||||
assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
|
||||
assert (lb[:, 1:] <= 1).all(), \
|
||||
f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
|
||||
assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
|
||||
# All labels
|
||||
max_cls = int(lb[:, 0].max()) # max label count
|
||||
assert max_cls <= num_cls, \
|
||||
f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
|
||||
f'Possible class labels are 0-{num_cls - 1}'
|
||||
assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
|
||||
_, i = np.unique(lb, axis=0, return_index=True)
|
||||
if len(i) < nl: # duplicate row check
|
||||
lb = lb[i] # remove duplicates
|
||||
@ -119,12 +113,18 @@ def verify_image_label(args):
|
||||
msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
|
||||
else:
|
||||
ne = 1 # label empty
|
||||
lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
|
||||
lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros(
|
||||
(0, 5), dtype=np.float32)
|
||||
else:
|
||||
nm = 1 # label missing
|
||||
lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
|
||||
lb = np.zeros((0, (5 + nkpt * ndim)), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
|
||||
if keypoint:
|
||||
keypoints = lb[:, 5:].reshape(-1, 17, 2)
|
||||
keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
|
||||
if ndim == 2:
|
||||
kpt_mask = np.ones(keypoints.shape[:2], dtype=np.float32)
|
||||
kpt_mask = np.where(keypoints[..., 0] < 0, 0.0, kpt_mask)
|
||||
kpt_mask = np.where(keypoints[..., 1] < 0, 0.0, kpt_mask)
|
||||
keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1) # (nl, nkpt, 3)
|
||||
lb = lb[:, :5]
|
||||
return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
|
||||
except Exception as e:
|
||||
@ -195,7 +195,7 @@ def check_det_dataset(dataset, autodownload=True):
|
||||
|
||||
# Download (optional)
|
||||
extract_dir = ''
|
||||
if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
|
||||
if isinstance(data, (str, Path)) and (zipfile.is_zipfile(data) or is_tarfile(data)):
|
||||
new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False)
|
||||
data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))
|
||||
extract_dir, autodownload = data.parent, False
|
||||
@ -356,23 +356,8 @@ class HUBDatasetStats():
|
||||
assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
|
||||
return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path
|
||||
|
||||
def _hub_ops(self, f, max_dim=1920):
|
||||
# HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
|
||||
f_new = self.im_dir / Path(f).name # dataset-hub image filename
|
||||
try: # use PIL
|
||||
im = Image.open(f)
|
||||
r = max_dim / max(im.height, im.width) # ratio
|
||||
if r < 1.0: # image too large
|
||||
im = im.resize((int(im.width * r), int(im.height * r)))
|
||||
im.save(f_new, 'JPEG', quality=50, optimize=True) # save
|
||||
except Exception as e: # use OpenCV
|
||||
LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
|
||||
im = cv2.imread(f)
|
||||
im_height, im_width = im.shape[:2]
|
||||
r = max_dim / max(im_height, im_width) # ratio
|
||||
if r < 1.0: # image too large
|
||||
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
|
||||
cv2.imwrite(str(f_new), im)
|
||||
def _hub_ops(self, f):
|
||||
compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub
|
||||
|
||||
def get_json(self, save=False, verbose=False):
|
||||
# Return dataset JSON for Ultralytics HUB
|
||||
@ -426,3 +411,93 @@ class HUBDatasetStats():
|
||||
pass
|
||||
LOGGER.info(f'Done. All images saved to {self.im_dir}')
|
||||
return self.im_dir
|
||||
|
||||
|
||||
def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
|
||||
"""
|
||||
Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the
|
||||
Python Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will
|
||||
not be resized.
|
||||
|
||||
Args:
|
||||
f (str): The path to the input image file.
|
||||
f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
|
||||
max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
|
||||
quality (int, optional): The image compression quality as a percentage. Default is 50%.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Usage:
|
||||
from pathlib import Path
|
||||
from ultralytics.yolo.data.utils import compress_one_image
|
||||
for f in Path('/Users/glennjocher/Downloads/dataset').rglob('*.jpg'):
|
||||
compress_one_image(f)
|
||||
"""
|
||||
try: # use PIL
|
||||
im = Image.open(f)
|
||||
r = max_dim / max(im.height, im.width) # ratio
|
||||
if r < 1.0: # image too large
|
||||
im = im.resize((int(im.width * r), int(im.height * r)))
|
||||
im.save(f_new or f, 'JPEG', quality=quality, optimize=True) # save
|
||||
except Exception as e: # use OpenCV
|
||||
LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
|
||||
im = cv2.imread(f)
|
||||
im_height, im_width = im.shape[:2]
|
||||
r = max_dim / max(im_height, im_width) # ratio
|
||||
if r < 1.0: # image too large
|
||||
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
|
||||
cv2.imwrite(str(f_new or f), im)
|
||||
|
||||
|
||||
def delete_dsstore(path):
|
||||
"""
|
||||
Deletes all ".DS_store" files under a specified directory.
|
||||
|
||||
Args:
|
||||
path (str, optional): The directory path where the ".DS_store" files should be deleted.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Usage:
|
||||
from ultralytics.yolo.data.utils import delete_dsstore
|
||||
delete_dsstore('/Users/glennjocher/Downloads/dataset')
|
||||
|
||||
Note:
|
||||
".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
|
||||
are hidden system files and can cause issues when transferring files between different operating systems.
|
||||
"""
|
||||
# Delete Apple .DS_store files
|
||||
files = list(Path(path).rglob('.DS_store'))
|
||||
LOGGER.info(f'Deleting *.DS_store files: {files}')
|
||||
for f in files:
|
||||
f.unlink()
|
||||
|
||||
|
||||
def zip_directory(dir, use_zipfile_library=True):
|
||||
"""Zips a directory and saves the archive to the specified output path.
|
||||
|
||||
Args:
|
||||
dir (str): The path to the directory to be zipped.
|
||||
use_zipfile_library (bool): Whether to use zipfile library or shutil for zipping.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Usage:
|
||||
from ultralytics.yolo.data.utils import zip_directory
|
||||
zip_directory('/Users/glennjocher/Downloads/playground')
|
||||
|
||||
zip -r coco8-pose.zip coco8-pose
|
||||
"""
|
||||
delete_dsstore(dir)
|
||||
if use_zipfile_library:
|
||||
dir = Path(dir)
|
||||
with zipfile.ZipFile(dir.with_suffix('.zip'), 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
||||
for file_path in dir.glob('**/*'):
|
||||
if file_path.is_file():
|
||||
zip_file.write(file_path, file_path.relative_to(dir))
|
||||
else:
|
||||
import shutil
|
||||
shutil.make_archive(dir, 'zip', dir)
|
||||
|
Reference in New Issue
Block a user