`ultralytics 8.0.74` Pose labels, fp64 labels, Ensemble fixes (#1956)

Co-authored-by: jjlira <63210717+jjlira@users.noreply.github.com>
Co-authored-by: Jose Lira <jose.lira@georgebrown.ca>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: wilsonlmh <lu6ni4z-forum@Yahoo.com.hk>
Co-authored-by: HaeJin Lee <seareale@gmail.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 5629ed0bb7
commit 24363236f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -26,14 +26,14 @@
"<img width=\"1024\", src=\"https://github.com/ultralytics/assets/raw/main/im/ultralytics-hub.png\"></a>\n",
"\n",
"<div align=\"center\">\n",
" <a href=\"https://github.com/ultralytics/hub/actions/workflows/ci.yaml\">\n",
" <img src=\"https://github.com/ultralytics/hub/actions/workflows/ci.yaml/badge.svg\" alt=\"CI CPU\"></a>\n",
" <a href=\"https://colab.research.google.com/github/ultralytics/hub/blob/master/hub.ipynb\">\n",
" <a href=\"https://github.com/ultralytics/ultralytics/actions/workflows/ci.yaml\">\n",
" <img src=\"https://github.com/ultralytics/ultralytics/actions/workflows/ci.yaml/badge.svg\" alt=\"CI CPU\"></a>\n",
" <a href=\"https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/hub.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>\n",
"\n",
"Welcome to the [Ultralytics](https://ultralytics.com/) HUB notebook! \n",
"\n",
"This notebook allows you to train [YOLOv5](https://github.com/ultralytics/yolov5) and [YOLOv8](https://github.com/ultralytics/ultralytics) 🚀 models using [HUB](https://hub.ultralytics.com/). Please browse the YOLOv8 <a href=\"https://docs.ultralytics.com\">Docs</a> for details, raise an issue on <a href=\"https://github.com/ultralytics/hub/issues/new/choose\">GitHub</a> for support, and join our <a href=\"https://discord.gg/n6cFeSPZdD\">Discord</a> community for questions and discussions!\n",
"This notebook allows you to train [YOLOv5](https://github.com/ultralytics/yolov5) and [YOLOv8](https://github.com/ultralytics/ultralytics) 🚀 models using [HUB](https://hub.ultralytics.com/). Please browse the YOLOv8 <a href=\"https://docs.ultralytics.com\">Docs</a> for details, raise an issue on <a href=\"https://github.com/ultralytics/ultralytics/issues/new/choose\">GitHub</a> for support, and join our <a href=\"https://discord.gg/n6cFeSPZdD\">Discord</a> community for questions and discussions!\n",
"</div>"
]
},

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = '8.0.73'
__version__ = '8.0.74'
from ultralytics.hub import start
from ultralytics.yolo.engine.model import YOLO

@ -112,17 +112,21 @@ class HUBTrainingSession:
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
self.model_id = data['id']
self.train_args = {
'batch': data['batch' if ('batch' in data) else 'batch_size'], # TODO: deprecate 'batch_size' in 3Q23
'epochs': data['epochs'],
'imgsz': data['imgsz'],
'patience': data['patience'],
'device': data['device'],
'cache': data['cache'],
'data': data['data']}
self.model_file = data.get('cfg', data['weights'])
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
if data['status'] == 'new': # new model to start training
self.train_args = {
# TODO: deprecate 'batch_size' key for 'batch' in 3Q23
'batch': data['batch' if ('batch' in data) else 'batch_size'],
'epochs': data['epochs'],
'imgsz': data['imgsz'],
'patience': data['patience'],
'device': data['device'],
'cache': data['cache'],
'data': data['data']}
self.model_file = data.get('cfg', data['weights'])
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
elif data['status'] == 'training': # existing model to resume training
self.train_args = {'data': data['data'], 'resume': True}
self.model_file = data['resume']
return data
except requests.exceptions.ConnectionError as e:

@ -374,7 +374,7 @@ class Ensemble(nn.ModuleList):
y = [module(x, augment, profile, visualize)[0] for module in self]
# y = torch.stack(y).max(0)[0] # max ensemble
# y = torch.stack(y).mean(0) # mean ensemble
y = torch.cat(y, 1) # nms ensemble
y = torch.cat(y, 2) # nms ensemble, y shape(B, HW, C)
return y, None # inference, train output

@ -127,7 +127,7 @@ class Mosaic(BaseMixTransform):
s = self.imgsz
yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y
for i in range(4):
labels_patch = (labels if i == 0 else labels['mix_labels'][i - 1]).copy()
labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
# Load image
img = labels_patch['img']
h, w = labels_patch.pop('resized_shape')
@ -223,18 +223,18 @@ class RandomPerspective:
def affine_transform(self, img, border):
# Center
C = np.eye(3)
C = np.eye(3, dtype=np.float32)
C[0, 2] = -img.shape[1] / 2 # x translation (pixels)
C[1, 2] = -img.shape[0] / 2 # y translation (pixels)
# Perspective
P = np.eye(3)
P = np.eye(3, dtype=np.float32)
P[2, 0] = random.uniform(-self.perspective, self.perspective) # x perspective (about y)
P[2, 1] = random.uniform(-self.perspective, self.perspective) # y perspective (about x)
# Rotation and Scale
R = np.eye(3)
R = np.eye(3, dtype=np.float32)
a = random.uniform(-self.degrees, self.degrees)
# a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
s = random.uniform(1 - self.scale, 1 + self.scale)
@ -242,12 +242,12 @@ class RandomPerspective:
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
# Shear
S = np.eye(3)
S = np.eye(3, dtype=np.float32)
S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # x shear (deg)
S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # y shear (deg)
# Translation
T = np.eye(3)
T = np.eye(3, dtype=np.float32)
T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0] # x translation (pixels)
T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1] # y translation (pixels)
@ -274,7 +274,7 @@ class RandomPerspective:
if n == 0:
return bboxes
xy = np.ones((n * 4, 3))
xy = np.ones((n * 4, 3), dtype=bboxes.dtype)
xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
xy = xy @ M.T # transform
xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
@ -282,7 +282,7 @@ class RandomPerspective:
# create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
def apply_segments(self, segments, M):
"""apply affine to segments and generate new bboxes from segments.
@ -298,7 +298,7 @@ class RandomPerspective:
if n == 0:
return [], segments
xy = np.ones((n * num, 3))
xy = np.ones((n * num, 3), dtype=segments.dtype)
segments = segments.reshape(-1, 2)
xy[:, :2] = segments
xy = xy @ M.T # transform
@ -319,7 +319,7 @@ class RandomPerspective:
n, nkpt = keypoints.shape[:2]
if n == 0:
return keypoints
xy = np.ones((n * nkpt, 3))
xy = np.ones((n * nkpt, 3), dtype=keypoints.dtype)
visible = keypoints[..., 2].reshape(n * nkpt, 1)
xy[:, :2] = keypoints[..., :2].reshape(n * nkpt, 2)
xy = xy @ M.T # transform

@ -3,6 +3,7 @@
import glob
import math
import os
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Optional
@ -177,7 +178,7 @@ class BaseDataset(Dataset):
return self.transforms(self.get_label_info(index))
def get_label_info(self, index):
label = self.labels[index].copy()
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
label.pop('shape', None) # shape is for rect, remove it
label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0],

@ -166,7 +166,9 @@ class YOLO:
"""
Raises TypeError is model is not a PyTorch model
"""
if not isinstance(self.model, nn.Module):
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
pt_module = isinstance(self.model, nn.Module)
if not (pt_module or pt_str):
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
f'PyTorch models can be used to train, val, predict and export, i.e. '
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "

@ -40,7 +40,7 @@ def on_model_save(trainer):
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
LOGGER.info(f'{PREFIX}Uploading checkpoint {session.model_id}')
LOGGER.info(f'{PREFIX}Uploading checkpoint https://hub.ultralytics.com/models/{session.model_id}')
session.upload_model(trainer.epoch, trainer.last, is_best)
session.timers['ckpt'] = time() # reset timer

@ -81,7 +81,8 @@ def segment2box(segment, width=640, height=640):
x, y = segment.T # segment xy
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
x, y, = x[inside], y[inside]
return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros(4) # xyxy
return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros(
4, dtype=segment.dtype) # xyxy
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
@ -529,7 +530,8 @@ def resample_segments(segments, n=1000):
s = np.concatenate((s, s[0:1, :]), axis=0)
x = np.linspace(0, len(s) - 1, n)
xp = np.arange(len(s))
segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)],
dtype=np.float32).reshape(2, -1).T # segment xy
return segments

Loading…
Cancel
Save