From 24363236f22e164a853423671530d0c7c998531b Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 11 Apr 2023 15:30:01 +0200
Subject: [PATCH] `ultralytics 8.0.74` Pose labels, fp64 labels, Ensemble fixes (#1956)

Co-authored-by: jjlira <63210717+jjlira@users.noreply.github.com>
Co-authored-by: Jose Lira
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: wilsonlmh
Co-authored-by: HaeJin Lee
---
 examples/hub.ipynb                      |  8 ++++----
 ultralytics/__init__.py                 |  2 +-
 ultralytics/hub/session.py              | 26 ++++++++++++++-----------
 ultralytics/nn/modules.py               |  2 +-
 ultralytics/yolo/data/augment.py        | 20 +++++++++----------
 ultralytics/yolo/data/base.py           |  3 ++-
 ultralytics/yolo/engine/model.py        |  4 +++-
 ultralytics/yolo/utils/callbacks/hub.py |  2 +-
 ultralytics/yolo/utils/ops.py           |  6 ++++--
 9 files changed, 41 insertions(+), 32 deletions(-)
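
A note on the Ensemble fix in `ultralytics/nn/modules.py` further down: an NMS-style ensemble keeps every model's raw predictions and lets a single NMS pass merge the candidate boxes, so the per-model outputs have to be concatenated along the prediction axis rather than the channel axis (the old `dim=1`). A minimal sketch with hypothetical output shapes, assuming outputs shaped (batch, channels, num_predictions) rather than the exact YOLOv8 head layout:

```python
import torch

# Two hypothetical detector outputs: (batch, channels, num_predictions),
# e.g. 84 channels = 4 box coordinates + 80 class scores.
y1 = torch.rand(1, 84, 5040)  # model A
y2 = torch.rand(1, 84, 6300)  # model B

# NMS ensemble: place all candidate predictions side by side and let NMS merge them
y = torch.cat([y1, y2], dim=2)
print(y.shape)  # torch.Size([1, 84, 11340])

# torch.cat([y1, y2], dim=1) would require matching prediction counts and would
# stack channels instead, which downstream NMS could not interpret.
```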
diff --git a/examples/hub.ipynb b/examples/hub.ipynb
index 6839aec..0e3a9e4 100644
--- a/examples/hub.ipynb
+++ b/examples/hub.ipynb
@@ -26,14 +26,14 @@
  "\n",
  "\n",
  "\n",
- " \n",
- " \"CI\n",
- " \n",
+ " \n",
+ " \"CI\n",
+ " \n",
  " \"Open\n",
  "\n",
  "Welcome to the [Ultralytics](https://ultralytics.com/) HUB notebook! \n",
  "\n",
- "This notebook allows you to train [YOLOv5](https://github.com/ultralytics/yolov5) and [YOLOv8](https://github.com/ultralytics/ultralytics) 🚀 models using [HUB](https://hub.ultralytics.com/). Please browse the YOLOv8 Docs for details, raise an issue on GitHub for support, and join our Discord community for questions and discussions!\n",
+ "This notebook allows you to train [YOLOv5](https://github.com/ultralytics/yolov5) and [YOLOv8](https://github.com/ultralytics/ultralytics) 🚀 models using [HUB](https://hub.ultralytics.com/). Please browse the YOLOv8 Docs for details, raise an issue on GitHub for support, and join our Discord community for questions and discussions!\n",
  ""
  ]
 },
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index dae00e0..61c768f 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
 
-__version__ = '8.0.73'
+__version__ = '8.0.74'
 
 from ultralytics.hub import start
 from ultralytics.yolo.engine.model import YOLO
diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py
index a0cf72c..a2aa6f5 100644
--- a/ultralytics/hub/session.py
+++ b/ultralytics/hub/session.py
@@ -112,17 +112,21 @@ class HUBTrainingSession:
                 raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix
             self.model_id = data['id']
-            self.train_args = {
-                'batch': data['batch' if ('batch' in data) else 'batch_size'],  # TODO: deprecate 'batch_size' in 3Q23
-                'epochs': data['epochs'],
-                'imgsz': data['imgsz'],
-                'patience': data['patience'],
-                'device': data['device'],
-                'cache': data['cache'],
-                'data': data['data']}
-
-            self.model_file = data.get('cfg', data['weights'])
-            self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
+            if data['status'] == 'new':  # new model to start training
+                self.train_args = {
+                    # TODO: deprecate 'batch_size' key for 'batch' in 3Q23
+                    'batch': data['batch' if ('batch' in data) else 'batch_size'],
+                    'epochs': data['epochs'],
+                    'imgsz': data['imgsz'],
+                    'patience': data['patience'],
+                    'device': data['device'],
+                    'cache': data['cache'],
+                    'data': data['data']}
+                self.model_file = data.get('cfg', data['weights'])
+                self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
+            elif data['status'] == 'training':  # existing model to resume training
+                self.train_args = {'data': data['data'], 'resume': True}
+                self.model_file = data['resume']
             return data
         except requests.exceptions.ConnectionError as e:
diff --git a/ultralytics/nn/modules.py b/ultralytics/nn/modules.py
index ee21d79..c53037d 100644
--- a/ultralytics/nn/modules.py
+++ b/ultralytics/nn/modules.py
@@ -374,7 +374,7 @@ class Ensemble(nn.ModuleList):
         y = [module(x, augment, profile, visualize)[0] for module in self]
         # y = torch.stack(y).max(0)[0]  # max ensemble
         # y = torch.stack(y).mean(0)  # mean ensemble
-        y = torch.cat(y, 1)  # nms ensemble
+        y = torch.cat(y, 2)  # nms ensemble, y shape(B, HW, C)
         return y, None  # inference, train output
diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py
index a234da4..0394f33 100644
--- a/ultralytics/yolo/data/augment.py
+++ b/ultralytics/yolo/data/augment.py
@@ -127,7 +127,7 @@ class Mosaic(BaseMixTransform):
         s = self.imgsz
         yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border)  # mosaic center x, y
         for i in range(4):
-            labels_patch = (labels if i == 0 else labels['mix_labels'][i - 1]).copy()
+            labels_patch = labels if i == 0 else labels['mix_labels'][i - 1]
             # Load image
             img = labels_patch['img']
             h, w = labels_patch.pop('resized_shape')
@@ -223,18 +223,18 @@ class RandomPerspective:
 
     def affine_transform(self, img, border):
         # Center
-        C = np.eye(3)
+        C = np.eye(3, dtype=np.float32)
         C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
         C[1, 2] = -img.shape[0] / 2  # y translation (pixels)
 
         # Perspective
-        P = np.eye(3)
+        P = np.eye(3, dtype=np.float32)
         P[2, 0] = random.uniform(-self.perspective, self.perspective)  # x perspective (about y)
         P[2, 1] = random.uniform(-self.perspective, self.perspective)  # y perspective (about x)
 
         # Rotation and Scale
-        R = np.eye(3)
+        R = np.eye(3, dtype=np.float32)
         a = random.uniform(-self.degrees, self.degrees)
         # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
         s = random.uniform(1 - self.scale, 1 + self.scale)
@@ -242,12 +242,12 @@ class RandomPerspective:
         R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
 
         # Shear
-        S = np.eye(3)
+        S = np.eye(3, dtype=np.float32)
         S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180)  # x shear (deg)
         S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180)  # y shear (deg)
 
         # Translation
-        T = np.eye(3)
+        T = np.eye(3, dtype=np.float32)
         T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0]  # x translation (pixels)
         T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1]  # y translation (pixels)
@@ -274,7 +274,7 @@ class RandomPerspective:
         if n == 0:
             return bboxes
 
-        xy = np.ones((n * 4, 3))
+        xy = np.ones((n * 4, 3), dtype=bboxes.dtype)
         xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
         xy = xy @ M.T  # transform
         xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine
@@ -282,7 +282,7 @@ class RandomPerspective:
         # create new boxes
         x = xy[:, [0, 2, 4, 6]]
         y = xy[:, [1, 3, 5, 7]]
-        return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+        return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
 
     def apply_segments(self, segments, M):
         """apply affine to segments and generate new bboxes from segments.
@@ -298,7 +298,7 @@ class RandomPerspective:
         if n == 0:
             return [], segments
 
-        xy = np.ones((n * num, 3))
+        xy = np.ones((n * num, 3), dtype=segments.dtype)
         segments = segments.reshape(-1, 2)
         xy[:, :2] = segments
         xy = xy @ M.T  # transform
@@ -319,7 +319,7 @@ class RandomPerspective:
         n, nkpt = keypoints.shape[:2]
         if n == 0:
             return keypoints
-        xy = np.ones((n * nkpt, 3))
+        xy = np.ones((n * nkpt, 3), dtype=keypoints.dtype)
         visible = keypoints[..., 2].reshape(n * nkpt, 1)
         xy[:, :2] = keypoints[..., :2].reshape(n * nkpt, 2)
         xy = xy @ M.T  # transform
diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py
index 28db054..ae64290 100644
--- a/ultralytics/yolo/data/base.py
+++ b/ultralytics/yolo/data/base.py
@@ -3,6 +3,7 @@
 import glob
 import math
 import os
+from copy import deepcopy
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
 from typing import Optional
@@ -177,7 +178,7 @@ class BaseDataset(Dataset):
         return self.transforms(self.get_label_info(index))
 
     def get_label_info(self, index):
-        label = self.labels[index].copy()
+        label = deepcopy(self.labels[index])  # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
         label.pop('shape', None)  # shape is for rect, remove it
         label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
         label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0],
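
The `deepcopy` change in `base.py` above (together with dropping `.copy()` in `Mosaic`) is the core of the pose-label fix: `dict.copy()` is shallow, so the NumPy arrays inside a cached label dict stay shared with the dataset cache, and in-place augmentation then corrupts those labels for later epochs. A minimal standalone sketch of the difference, using generic field names rather than the exact Ultralytics label schema:

```python
from copy import deepcopy

import numpy as np

# Hypothetical cached label entry (field names are illustrative only)
label = {'im_file': 'img0.jpg', 'keypoints': np.zeros((1, 17, 3))}

shallow = label.copy()           # new dict, but the arrays are still shared
shallow['keypoints'] += 1.0      # in-place augmentation leaks back into the cache
print(label['keypoints'].max())  # 1.0 -> cached label corrupted

label = {'im_file': 'img0.jpg', 'keypoints': np.zeros((1, 17, 3))}
deep = deepcopy(label)           # arrays are copied too
deep['keypoints'] += 1.0
print(label['keypoints'].max())  # 0.0 -> cached label untouched
```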
diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py
index d81dc29..26e19fd 100644
--- a/ultralytics/yolo/engine/model.py
+++ b/ultralytics/yolo/engine/model.py
@@ -166,7 +166,9 @@ class YOLO:
         """
         Raises TypeError is model is not a PyTorch model
         """
-        if not isinstance(self.model, nn.Module):
+        pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
+        pt_module = isinstance(self.model, nn.Module)
+        if not (pt_module or pt_str):
             raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
                             f'PyTorch models can be used to train, val, predict and export, i.e. '
                             f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
diff --git a/ultralytics/yolo/utils/callbacks/hub.py b/ultralytics/yolo/utils/callbacks/hub.py
index 7d127cd..b485895 100644
--- a/ultralytics/yolo/utils/callbacks/hub.py
+++ b/ultralytics/yolo/utils/callbacks/hub.py
@@ -40,7 +40,7 @@ def on_model_save(trainer):
     # Upload checkpoints with rate limiting
     is_best = trainer.best_fitness == trainer.fitness
     if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
-        LOGGER.info(f'{PREFIX}Uploading checkpoint {session.model_id}')
+        LOGGER.info(f'{PREFIX}Uploading checkpoint https://hub.ultralytics.com/models/{session.model_id}')
         session.upload_model(trainer.epoch, trainer.last, is_best)
         session.timers['ckpt'] = time()  # reset timer
diff --git a/ultralytics/yolo/utils/ops.py b/ultralytics/yolo/utils/ops.py
index 26b6fe8..a59197f 100644
--- a/ultralytics/yolo/utils/ops.py
+++ b/ultralytics/yolo/utils/ops.py
@@ -81,7 +81,8 @@ def segment2box(segment, width=640, height=640):
     x, y = segment.T  # segment xy
     inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
     x, y, = x[inside], y[inside]
-    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros(4)  # xyxy
+    return np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) if any(x) else np.zeros(
+        4, dtype=segment.dtype)  # xyxy
 
 
 def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
@@ -529,7 +530,8 @@ def resample_segments(segments, n=1000):
         s = np.concatenate((s, s[0:1, :]), axis=0)
         x = np.linspace(0, len(s) - 1, n)
         xp = np.arange(len(s))
-        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
+        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)],
+                                     dtype=np.float32).reshape(2, -1).T  # segment xy
     return segments
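
The repeated `dtype=` additions in `augment.py` and `ops.py` relate to the fp64-labels part of this release: `np.eye`, `np.ones` and `np.concatenate` default to float64, so mixing them with label arrays silently promotes boxes, segments and keypoints to double precision. Pinning the dtype keeps the augmentation math in the labels' own dtype. A small NumPy-only illustration of the default promotion (not Ultralytics code):

```python
import numpy as np

boxes = np.random.rand(4, 4).astype(np.float32)  # float32 labels, as a dataloader might produce

M = np.eye(3)                               # NumPy defaults to float64
corners = np.ones((boxes.shape[0] * 4, 3))  # also float64
print(M.dtype, corners.dtype)               # float64 float64

upcast = boxes @ np.eye(4)                  # float32 @ float64 -> float64
print(upcast.dtype)                         # float64

kept = boxes @ np.eye(4, dtype=boxes.dtype)  # stays float32 when the dtype is pinned
print(kept.dtype)                            # float32
```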