From 28e48be5b68f11a7a71029b7b7bccf6a3ac26a66 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 24 Mar 2023 00:22:20 +0100 Subject: [PATCH] `ultralytics 8.0.56` PyTorch 2.0 support and minor fixes (#1538) Co-authored-by: N-Friederich <127681326+N-Friederich@users.noreply.github.com> Co-authored-by: Uhrendoktor <36703334+Uhrendoktor@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Aman Agarwal Co-authored-by: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> Co-authored-by: Nadav Eidelstein <30617226+nodeav@users.noreply.github.com> --- .github/workflows/ci.yaml | 10 ++-- README.md | 48 ++++++++--------- README.zh-CN.md | 48 ++++++++--------- docs/app.md | 9 ++-- docs/hub.md | 9 ++-- ultralytics/__init__.py | 2 +- ultralytics/models/v8/yolov8-p2.yaml | 54 +++++++++++++++++++ ultralytics/tracker/trackers/basetrack.py | 4 ++ ultralytics/tracker/trackers/byte_tracker.py | 4 ++ .../yolo/data/dataloaders/stream_loaders.py | 2 +- ultralytics/yolo/data/dataset.py | 3 ++ ultralytics/yolo/engine/results.py | 17 +++--- ultralytics/yolo/utils/__init__.py | 2 +- ultralytics/yolo/utils/checks.py | 5 +- ultralytics/yolo/utils/torch_utils.py | 4 +- ultralytics/yolo/v8/detect/train.py | 1 + ultralytics/yolo/v8/detect/val.py | 13 ++++- ultralytics/yolo/v8/segment/train.py | 2 +- 18 files changed, 149 insertions(+), 88 deletions(-) create mode 100644 ultralytics/models/v8/yolov8-p2.yaml diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0b1be69..e236ff5 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -7,7 +7,7 @@ on: push: branches: [main] pull_request: - branches: [main] + branches: [main, updates] schedule: - cron: '0 0 * * *' # runs at 00:00 UTC every day @@ -105,7 +105,9 @@ jobs: from ultralytics.yolo.utils.benchmarks import benchmark benchmark(model='${{ matrix.model }}-cls.pt', imgsz=160, half=False, hard_fail=0.61) - name: Benchmark Summary - run: cat benchmarks.log + run: | + cat benchmarks.log + echo "$(cat benchmarks.log)" >> $GITHUB_STEP_SUMMARY Tests: timeout-minutes: 60 @@ -133,9 +135,9 @@ jobs: run: | python -m pip install --upgrade pip wheel if [ "${{ matrix.torch }}" == "1.8.0" ]; then - pip install -e '.[export]' torch==1.8.0 torchvision==0.9.0 pytest --extra-index-url https://download.pytorch.org/whl/cpu + pip install -e . torch==1.8.0 torchvision==0.9.0 pytest --extra-index-url https://download.pytorch.org/whl/cpu else - pip install -e '.[export]' pytest --extra-index-url https://download.pytorch.org/whl/cpu + pip install -e . pytest --extra-index-url https://download.pytorch.org/whl/cpu fi - name: Check environment run: | diff --git a/README.md b/README.md index 86a4232..7713612 100644 --- a/README.md +++ b/README.md @@ -29,27 +29,24 @@ To request an Enterprise License please complete the form at [Ultralytics Licens
- - - - - - - - - - - - - - - - - - - - -
+ + + + + + + + + + + + + + + + + + ##
Documentation
@@ -262,20 +259,17 @@ the [Ultralytics Community Forum](https://community.ultralytics.com/). - + - - - - - + + diff --git a/README.zh-CN.md b/README.zh-CN.md index 9c6d554..2765db5 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -25,27 +25,24 @@ SOTA 模型。它在以前成功的 YOLO 版本基础上,引入了新的功能 如果要申请企业许可证,请填写 [Ultralytics 许可](https://ultralytics.com/license)。
- - - - - - - - - - - - - - - - - - - - -
+ + + + + + + + + + + + + + + + + + ##
文档
@@ -239,20 +236,17 @@ YOLOv8 在两种不同的 License 下可用: - + - - - - - + + diff --git a/docs/app.md b/docs/app.md index 9f1b504..8aaf686 100644 --- a/docs/app.md +++ b/docs/app.md @@ -7,20 +7,17 @@ - + - - - - - + + diff --git a/docs/hub.md b/docs/hub.md index 0d0cbd0..199fa63 100644 --- a/docs/hub.md +++ b/docs/hub.md @@ -7,20 +7,17 @@ - + - - - - - + + diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 45ce04c..8565dd0 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -__version__ = '8.0.55' +__version__ = '8.0.56' from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.utils.checks import check_yolo as checks diff --git a/ultralytics/models/v8/yolov8-p2.yaml b/ultralytics/models/v8/yolov8-p2.yaml new file mode 100644 index 0000000..f91a98c --- /dev/null +++ b/ultralytics/models/v8/yolov8-p2.yaml @@ -0,0 +1,54 @@ +# Ultralytics YOLO 🚀, GPL-3.0 license +# YOLOv8 object detection model with P2-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv8.0 backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0-p2 head +head: + - [-1, 1, nn.Upsample, [None, 2, 'nearest']] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, 'nearest']] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, nn.Upsample, [None, 2, 'nearest']] + - [[-1, 2], 1, Concat, [1]] # cat backbone P2 + - [-1, 3, C2f, [128]] # 18 (P2/4-xsmall) + + - [-1, 1, Conv, [128, 3, 2]] + - [[-1, 15], 1, Concat, [1]] # cat head P3 + - [-1, 3, C2f, [256]] # 21 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 24 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 27 (P5/32-large) + + - [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5) diff --git a/ultralytics/tracker/trackers/basetrack.py b/ultralytics/tracker/trackers/basetrack.py index a173f03..71c8541 100644 --- a/ultralytics/tracker/trackers/basetrack.py +++ b/ultralytics/tracker/trackers/basetrack.py @@ -53,3 +53,7 @@ class BaseTrack: def mark_removed(self): self.state = TrackState.Removed + + @staticmethod + def reset_id(): + BaseTrack._count = 0 diff --git a/ultralytics/tracker/trackers/byte_tracker.py b/ultralytics/tracker/trackers/byte_tracker.py index 504e9e2..a6103e2 100644 --- a/ultralytics/tracker/trackers/byte_tracker.py +++ b/ultralytics/tracker/trackers/byte_tracker.py @@ -168,6 +168,7 @@ class BYTETracker: self.args = args self.max_time_lost = int(frame_rate / 30.0 * args.track_buffer) self.kalman_filter = self.get_kalmanfilter() + self.reset_id() def update(self, results, img=None): self.frame_id += 1 @@ -299,6 +300,9 @@ class BYTETracker: def multi_predict(self, tracks): STrack.multi_predict(tracks) + def reset_id(self): + STrack.reset_id() + @staticmethod def joint_stracks(tlista, tlistb): exists = {} diff --git a/ultralytics/yolo/data/dataloaders/stream_loaders.py b/ultralytics/yolo/data/dataloaders/stream_loaders.py index 05b3ac3..0dec447 100644 --- a/ultralytics/yolo/data/dataloaders/stream_loaders.py +++ b/ultralytics/yolo/data/dataloaders/stream_loaders.py @@ -284,6 +284,7 @@ class LoadPilAndNumpy: def __init__(self, im0, imgsz=640, stride=32, auto=True, transforms=None): if not isinstance(im0, list): im0 = [im0] + self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)] self.im0 = [self._single_check(im) for im in im0] self.imgsz = imgsz self.stride = stride @@ -291,7 +292,6 @@ class LoadPilAndNumpy: self.transforms = transforms self.mode = 'image' # generate fake paths - self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(self.im0)] self.bs = len(self.im0) @staticmethod diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py index 2412882..048e16e 100644 --- a/ultralytics/yolo/data/dataset.py +++ b/ultralytics/yolo/data/dataset.py @@ -125,7 +125,10 @@ class YOLODataset(BaseDataset): self.label_files = img2label_paths(self.im_files) cache_path = Path(self.label_files[0]).parent.with_suffix('.cache') try: + import gc + gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585 cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict + gc.enable() assert cache['version'] == self.cache_version # matches current version assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash except (FileNotFoundError, AssertionError, AttributeError): diff --git a/ultralytics/yolo/engine/results.py b/ultralytics/yolo/engine/results.py index 02a6839..3b70877 100644 --- a/ultralytics/yolo/engine/results.py +++ b/ultralytics/yolo/engine/results.py @@ -129,13 +129,14 @@ class Results: annotator = Annotator(deepcopy(self.orig_img), line_width, font_size, font, pil, example) boxes = self.boxes masks = self.masks - logits = self.probs + probs = self.probs names = self.names + hide_labels, hide_conf = False, not show_conf if boxes is not None: for d in reversed(boxes): - cls, conf = d.cls.squeeze(), d.conf.squeeze() - c = int(cls) - label = (f'{names[c]}' if names else f'{c}') + (f'{conf:.2f}' if show_conf else '') + c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) + name = ('' if id is None else f'id:{id} ') + names[c] + label = None if hide_labels else (name if hide_conf else f'{name} {conf:.2f}') annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if masks is not None: @@ -143,10 +144,10 @@ class Results: im = F.resize(im.contiguous(), masks.data.shape[1:]) / 255 annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im) - if logits is not None: - n5 = min(len(self.names), 5) - top5i = logits.argsort(0, descending=True)[:n5].tolist() # top 5 indices - text = f"{', '.join(f'{names[j] if names else j} {logits[j]:.2f}' for j in top5i)}, " + if probs is not None: + n5 = min(len(names), 5) + top5i = probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices + text = f"{', '.join(f'{names[j] if names else j} {probs[j]:.2f}' for j in top5i)}, " annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors return np.asarray(annotator.im) if annotator.pil else annotator.im diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index a821aa6..5b7cb6d 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -445,7 +445,7 @@ def get_user_config_dir(sub_dir='Ultralytics'): return path -USER_CONFIG_DIR = get_user_config_dir() # Ultralytics settings dir +USER_CONFIG_DIR = os.getenv('YOLO_CONFIG_DIR', get_user_config_dir()) # Ultralytics settings dir def emojis(string=''): diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py index d136246..7597087 100644 --- a/ultralytics/yolo/utils/checks.py +++ b/ultralytics/yolo/utils/checks.py @@ -21,7 +21,7 @@ import torch from matplotlib import font_manager from ultralytics.yolo.utils import (AUTOINSTALL, LOGGER, ONLINE, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads, - emojis, is_colab, is_docker, is_jupyter, is_online, is_pip_package) + emojis, is_colab, is_docker, is_kaggle, is_online, is_pip_package) def is_ascii(s) -> bool: @@ -292,8 +292,7 @@ def check_yaml(file, suffix=('.yaml', '.yml'), hard=True): def check_imshow(warn=False): # Check if environment supports image displays try: - assert not is_jupyter() - assert not is_docker() + assert not any((is_colab(), is_kaggle(), is_docker())) cv2.imshow('test', np.zeros((1, 1, 3))) cv2.waitKey(1) cv2.destroyAllWindows() diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py index 6ab54d0..5b510c4 100644 --- a/ultralytics/yolo/utils/torch_utils.py +++ b/ultralytics/yolo/utils/torch_utils.py @@ -27,6 +27,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) TORCH_1_9 = check_version(torch.__version__, '1.9.0') TORCH_1_11 = check_version(torch.__version__, '1.11.0') TORCH_1_12 = check_version(torch.__version__, '1.12.0') +TORCH_2_X = check_version(torch.__version__, minimum='2.0') @contextmanager @@ -95,7 +96,8 @@ def select_device(device='', batch=0, newline=False, verbose=True): p = torch.cuda.get_device_properties(i) s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB arg = 'cuda:0' - elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available(): # prefer MPS if available + elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available() and TORCH_2_X: + # prefer MPS if available s += 'MPS\n' arg = 'mps' else: # revert to CPU diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py index d2584db..6484cd7 100644 --- a/ultralytics/yolo/v8/detect/train.py +++ b/ultralytics/yolo/v8/detect/train.py @@ -134,6 +134,7 @@ class Loss: else: i = targets[:, 0] # image index _, counts = i.unique(return_counts=True) + counts = counts.to(dtype=torch.int32) out = torch.zeros(batch_size, counts.max(), 5, device=self.device) for j in range(batch_size): matches = i == j diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py index b4bc45e..5d09942 100644 --- a/ultralytics/yolo/v8/detect/val.py +++ b/ultralytics/yolo/v8/detect/val.py @@ -108,8 +108,9 @@ class DetectionValidator(BaseValidator): # Save if self.args.save_json: self.pred_to_json(predn, batch['im_file'][si]) - # if self.args.save_txt: - # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') + if self.args.save_txt: + file = self.save_dir / 'labels' / f'{Path(batch["im_file"][si]).stem}.txt' + self.save_one_txt(predn, self.args.save_conf, shape, file) def finalize_metrics(self, *args, **kwargs): self.metrics.speed = self.speed @@ -197,6 +198,14 @@ class DetectionValidator(BaseValidator): fname=self.save_dir / f'val_batch{ni}_pred.jpg', names=self.names) # pred + def save_one_txt(self, predn, save_conf, shape, file): + gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh + for *xyxy, conf, cls in predn.tolist(): + xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh + line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format + with open(file, 'a') as f: + f.write(('%g ' * len(line)).rstrip() % line + '\n') + def pred_to_json(self, predn, filename): stem = Path(filename).stem image_id = int(stem) if stem.isnumeric() else stem diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index 3cc6466..5fc62b3 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -79,7 +79,7 @@ class SegLoss(Loss): # targets try: batch_idx = batch['batch_idx'].view(-1, 1) - targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1) + targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes'].to(dtype)), 1) targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)