From 5629ed0bb7029f3e3fc65f98875bb05a7f614705 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 11 Apr 2023 01:00:09 +0200 Subject: [PATCH] `ultralytics 8.0.73` minor fixes (#1929) Co-authored-by: Yonghye Kwon Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: joseliraGB <122470533+joseliraGB@users.noreply.github.com> --- .github/workflows/ci.yaml | 4 +- examples/hub.ipynb | 103 ++++++++++++++++++++++++ ultralytics/__init__.py | 2 +- ultralytics/hub/session.py | 4 +- ultralytics/nn/tasks.py | 2 +- ultralytics/yolo/data/base.py | 6 +- ultralytics/yolo/data/utils.py | 4 +- ultralytics/yolo/engine/model.py | 1 - ultralytics/yolo/engine/predictor.py | 31 ++++++- ultralytics/yolo/engine/results.py | 79 +++++++++++++++++- ultralytics/yolo/engine/trainer.py | 2 +- ultralytics/yolo/utils/__init__.py | 5 +- ultralytics/yolo/v8/classify/predict.py | 41 ---------- ultralytics/yolo/v8/detect/predict.py | 43 ---------- ultralytics/yolo/v8/pose/predict.py | 45 ----------- ultralytics/yolo/v8/segment/predict.py | 50 ------------ 16 files changed, 224 insertions(+), 198 deletions(-) create mode 100644 examples/hub.ipynb diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a6d60d0..7a274cf 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -76,9 +76,9 @@ jobs: run: | python -m pip install --upgrade pip wheel if [ "${{ matrix.os }}" == "macos-latest" ]; then - pip install -e . coremltools openvino-dev tensorflow-macos --extra-index-url https://download.pytorch.org/whl/cpu + pip install -e . coremltools openvino-dev tensorflow-macos tensorflowjs --extra-index-url https://download.pytorch.org/whl/cpu else - pip install -e . coremltools openvino-dev tensorflow-cpu --extra-index-url https://download.pytorch.org/whl/cpu + pip install -e . coremltools openvino-dev tensorflow-cpu tensorflowjs --extra-index-url https://download.pytorch.org/whl/cpu fi yolo export format=tflite - name: Check environment diff --git a/examples/hub.ipynb b/examples/hub.ipynb new file mode 100644 index 0000000..6839aec --- /dev/null +++ b/examples/hub.ipynb @@ -0,0 +1,103 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Ultralytics HUB", + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "FIzICjaph_Wy" + }, + "source": [ + "\n", + "\n", + "\n", + "
\n", + " \n", + " \"CI\n", + " \n", + " \"Open\n", + "\n", + "Welcome to the [Ultralytics](https://ultralytics.com/) HUB notebook! \n", + "\n", + "This notebook allows you to train [YOLOv5](https://github.com/ultralytics/yolov5) and [YOLOv8](https://github.com/ultralytics/ultralytics) 🚀 models using [HUB](https://hub.ultralytics.com/). Please browse the YOLOv8 Docs for details, raise an issue on GitHub for support, and join our Discord community for questions and discussions!\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eRQ2ow94MiOv" + }, + "source": [ + "# Setup\n", + "\n", + "Pip install `ultralytics` and [dependencies](https://github.com/ultralytics/ultralytics/blob/main/requirements.txt) and check software and hardware." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "FyDnXd-n4c7Y", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "22dcbc27-9c6f-44fb-9745-620431f93793" + }, + "source": [ + "%pip install ultralytics # install\n", + "from ultralytics import YOLO, checks, hub\n", + "checks() # checks" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Ultralytics YOLOv8.0.64 🚀 Python-3.9.16 torch-2.0.0+cu118 CUDA:0 (Tesla T4, 15102MiB)\n", + "Setup complete ✅ (2 CPUs, 12.7 GB RAM, 28.3/166.8 GB disk)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cQ9BwaAqxAm4" + }, + "source": [ + "# Start\n", + "\n", + "Login with your [API key](https://hub.ultralytics.com/settings?tab=api+keys), select your YOLO 🚀 model and start training!" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "XSlZaJ9Iw_iZ" + }, + "source": [ + "hub.login('API_KEY') # use your API key\n", + "\n", + "model = YOLO('https://hub.ultralytics.com/MODEL_ID') # use your model URL\n", + "model.train() # train model" + ], + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 8f38845..dae00e0 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -__version__ = '8.0.72' +__version__ = '8.0.73' from ultralytics.hub import start from ultralytics.yolo.engine.model import YOLO diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index 0f93f50..a0cf72c 100644 --- a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -112,10 +112,8 @@ class HUBTrainingSession: raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix self.model_id = data['id'] - # TODO: restore when server keys when dataset URL and GPU train is working - self.train_args = { - 'batch': data['batch_size'], + 'batch': data['batch' if ('batch' in data) else 'batch_size'], # TODO: deprecate 'batch_size' in 3Q23 'epochs': data['epochs'], 'imgsz': data['imgsz'], 'patience': data['patience'], diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index c8f4627..1109346 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -539,7 +539,7 @@ def guess_model_task(model): model (nn.Module) or (dict): PyTorch model or model configuration in YAML format. Returns: - str: Task of the model ('detect', 'segment', 'classify'). + str: Task of the model ('detect', 'segment', 'classify', 'pose'). Raises: SyntaxError: If the task of the model could not be determined. diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py index f9fd90f..28db054 100644 --- a/ultralytics/yolo/data/base.py +++ b/ultralytics/yolo/data/base.py @@ -180,10 +180,8 @@ class BaseDataset(Dataset): label = self.labels[index].copy() label.pop('shape', None) # shape is for rect, remove it label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index) - label['ratio_pad'] = ( - label['resized_shape'][0] / label['ori_shape'][0], - label['resized_shape'][1] / label['ori_shape'][1], - ) # for evaluation + label['ratio_pad'] = (label['resized_shape'][0] / label['ori_shape'][0], + label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation if self.rect: label['rect_shape'] = self.batch_shapes[self.batch[index]] label = self.update_labels_info(label) diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py index ffa4f3d..396315b 100644 --- a/ultralytics/yolo/data/utils.py +++ b/ultralytics/yolo/data/utils.py @@ -17,7 +17,8 @@ from PIL import ExifTags, Image, ImageOps from tqdm import tqdm from ultralytics.nn.autobackend import check_class_names -from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, clean_url, colorstr, emojis, yaml_load +from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, SETTINGS_YAML, clean_url, colorstr, emojis, + yaml_load) from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file from ultralytics.yolo.utils.ops import segments2boxes @@ -246,6 +247,7 @@ def check_det_dataset(dataset, autodownload=True): if s and autodownload: LOGGER.warning(m) else: + m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'" raise FileNotFoundError(m) t = time.time() if s.startswith('http') and s.endswith('.zip'): # URL diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 3469c2b..d81dc29 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -356,7 +356,6 @@ class YOLO: raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'") if overrides.get('resume'): overrides['resume'] = self.ckpt_path - self.task = overrides.get('task') or self.task self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks) if not overrides.get('resume'): # manually set model only if not resuming diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 2a6e68c..e266b0b 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -109,8 +109,35 @@ class BasePredictor: def preprocess(self, img): pass - def write_results(self, results, batch, print_string): - raise NotImplementedError('print_results function needs to be implemented') + def write_results(self, idx, results, batch): + p, im, _ = batch + log_string = '' + if len(im.shape) == 3: + im = im[None] # expand for batch dim + self.seen += 1 + if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 + log_string += f'{idx}: ' + frame = self.dataset.count + else: + frame = getattr(self.dataset, 'frame', 0) + self.data_path = p + self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') + log_string += '%gx%g ' % im.shape[2:] # print string + result = results[idx] + log_string += result.verbose() + + if self.args.save or self.args.show: # Add bbox to image + plot_args = dict(line_width=self.args.line_thickness, boxes=self.args.boxes) + if not self.args.retina_masks: + plot_args['im_gpu'] = im[idx] + self.plotted_img = result.plot(**plot_args) + # write + if self.args.save_txt: + result.save_txt(f'{self.txt_path}.txt', save_conf=self.args.save_conf) + if self.args.save_crop: + result.save_crop(save_dir=self.save_dir / 'crops', file_name=self.data_path.stem) + + return log_string def postprocess(self, preds, img, orig_img): return preds diff --git a/ultralytics/yolo/engine/results.py b/ultralytics/yolo/engine/results.py index b50ddc0..b6e780a 100644 --- a/ultralytics/yolo/engine/results.py +++ b/ultralytics/yolo/engine/results.py @@ -7,13 +7,14 @@ Usage: See https://docs.ultralytics.com/modes/predict/ from copy import deepcopy from functools import lru_cache +from pathlib import Path import numpy as np import torch from ultralytics.yolo.data.augment import LetterBox from ultralytics.yolo.utils import LOGGER, SimpleClass, deprecation_warn, ops -from ultralytics.yolo.utils.plotting import Annotator, colors +from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box class BaseTensor(SimpleClass): @@ -233,6 +234,80 @@ class Results(SimpleClass): return annotator.result() + def verbose(self): + """ + Return log string for each tasks. + """ + log_string = '' + probs = self.probs + boxes = self.boxes + if len(self) == 0: + return log_string if probs is not None else log_string + '(no detections), ' + if probs is not None: + n5 = min(len(self.names), 5) + top5i = probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices + log_string += f"{', '.join(f'{self.names[j]} {probs[j]:.2f}' for j in top5i)}, " + if boxes: + for c in boxes.cls.unique(): + n = (boxes.cls == c).sum() # detections per class + log_string += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " + return log_string + + def save_txt(self, txt_file, save_conf=False): + """Save predictions into txt file. + + Args: + txt_file (str): txt file path. + save_conf (bool): save confidence score or not. + """ + boxes = self.boxes + masks = self.masks + probs = self.probs + kpts = self.keypoints + texts = [] + if probs is not None: + # classify + n5 = min(len(self.names), 5) + top5i = probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices + [texts.append(f'{probs[j]:.2f} {self.names[j]}') for j in top5i] + elif boxes: + # detect/segment/pose + for j, d in enumerate(boxes): + c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) + line = (c, *d.xywhn.view(-1)) + if masks: + seg = masks[j].xyn[0].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2) + line = (c, *seg) + if kpts is not None: + kpt = (kpts[j][:, :2] / d.orig_shape[[1, 0]]).reshape(-1).tolist() + line += (*kpt, ) + line += (conf, ) * save_conf + (() if id is None else (id, )) + texts.append(('%g ' * len(line)).rstrip() % line) + + with open(txt_file, 'a') as f: + for text in texts: + f.write(text + '\n') + + def save_crop(self, save_dir, file_name=Path('im.jpg')): + """Save cropped predictions to `save_dir/cls/file_name.jpg`. + + Args: + save_dir (str | pathlib.Path): Save path. + file_name (str | pathlib.Path): File name. + """ + if self.probs is not None: + LOGGER.warning('Warning: Classify task do not support `save_crop`.') + return + if isinstance(save_dir, str): + save_dir = Path(save_dir) + if isinstance(file_name, str): + file_name = Path(file_name) + for d in self.boxes: + save_one_box(d.xyxy, + self.orig_img.copy(), + file=save_dir / self.names[int(d.cls)] / f'{file_name.stem}.jpg', + BGR=True) + class Boxes(BaseTensor): """ @@ -339,6 +414,8 @@ class Masks(BaseTensor): """ def __init__(self, masks, orig_shape) -> None: + if masks.ndim == 2: + masks = masks[None, :] self.masks = masks # N, h, w self.orig_shape = orig_shape diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index c400651..0100d9c 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -552,7 +552,7 @@ class BaseTrainer: if self.resume: assert start_epoch > 0, \ f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \ - f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'" + f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'" LOGGER.info( f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs') if self.epochs < start_epoch: diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index aec7e97..e9daf7c 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -490,6 +490,7 @@ def get_user_config_dir(sub_dir='Ultralytics'): USER_CONFIG_DIR = Path(os.getenv('YOLO_CONFIG_DIR', get_user_config_dir())) # Ultralytics settings dir +SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml' def emojis(string=''): @@ -591,7 +592,7 @@ def set_sentry(): logging.getLogger(logger).setLevel(logging.CRITICAL) -def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.3'): +def get_settings(file=SETTINGS_YAML, version='0.0.3'): """ Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist. @@ -640,7 +641,7 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.3'): return settings -def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'): +def set_settings(kwargs, file=SETTINGS_YAML): """ Function that runs on a first-time ultralytics package installation to set up global settings and create necessary directories. diff --git a/ultralytics/yolo/v8/classify/predict.py b/ultralytics/yolo/v8/classify/predict.py index e869dfd..b2bab40 100644 --- a/ultralytics/yolo/v8/classify/predict.py +++ b/ultralytics/yolo/v8/classify/predict.py @@ -5,14 +5,10 @@ import torch from ultralytics.yolo.engine.predictor import BasePredictor from ultralytics.yolo.engine.results import Results from ultralytics.yolo.utils import DEFAULT_CFG, ROOT -from ultralytics.yolo.utils.plotting import Annotator class ClassificationPredictor(BasePredictor): - def get_annotator(self, img): - return Annotator(img, example=str(self.model.names), pil=True) - def preprocess(self, img): img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 @@ -27,43 +23,6 @@ class ClassificationPredictor(BasePredictor): return results - def write_results(self, idx, results, batch): - p, im, im0 = batch - log_string = '' - if len(im.shape) == 3: - im = im[None] # expand for batch dim - self.seen += 1 - im0 = im0.copy() - if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 - log_string += f'{idx}: ' - frame = self.dataset.count - else: - frame = getattr(self.dataset, 'frame', 0) - - self.data_path = p - # save_path = str(self.save_dir / p.name) # im.jpg - self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') - log_string += '%gx%g ' % im.shape[2:] # print string - - result = results[idx] - if len(result) == 0: - return log_string - prob = result.probs - # Print results - n5 = min(len(self.model.names), 5) - top5i = prob.argsort(0, descending=True)[:n5].tolist() # top 5 indices - log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, " - - # write - if self.args.save or self.args.show: # Add bbox to image - self.plotted_img = result.plot() - if self.args.save_txt: # Write to file - text = '\n'.join(f'{prob[j]:.2f} {self.model.names[j]}' for j in top5i) - with open(f'{self.txt_path}.txt', 'a') as f: - f.write(text + '\n') - - return log_string - def predict(cfg=DEFAULT_CFG, use_python=False): model = cfg.model or 'yolov8n-cls.pt' # or "resnet18" diff --git a/ultralytics/yolo/v8/detect/predict.py b/ultralytics/yolo/v8/detect/predict.py index 4a9621a..afaf760 100644 --- a/ultralytics/yolo/v8/detect/predict.py +++ b/ultralytics/yolo/v8/detect/predict.py @@ -5,7 +5,6 @@ import torch from ultralytics.yolo.engine.predictor import BasePredictor from ultralytics.yolo.engine.results import Results from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops -from ultralytics.yolo.utils.plotting import save_one_box class DetectionPredictor(BasePredictor): @@ -34,48 +33,6 @@ class DetectionPredictor(BasePredictor): results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred)) return results - def write_results(self, idx, results, batch): - p, im, im0 = batch - log_string = '' - if len(im.shape) == 3: - im = im[None] # expand for batch dim - self.seen += 1 - imc = im0.copy() if self.args.save_crop else im0 - if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 - log_string += f'{idx}: ' - frame = self.dataset.count - else: - frame = getattr(self.dataset, 'frame', 0) - self.data_path = p - self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') - log_string += '%gx%g ' % im.shape[2:] # print string - - result = results[idx] # TODO: make boxes inherit from tensors - if len(result) == 0: - return f'{log_string}(no detections), ' - det = result.boxes - for c in det.cls.unique(): - n = (det.cls == c).sum() # detections per class - log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, " - - if self.args.save or self.args.show: # Add bbox to image - self.plotted_img = result.plot(line_width=self.args.line_thickness) - - # write - for d in reversed(det): - c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) - if self.args.save_txt: # Write to file - line = (c, *d.xywhn.view(-1)) + (conf, ) * self.args.save_conf + (() if id is None else (id, )) - with open(f'{self.txt_path}.txt', 'a') as f: - f.write(('%g ' * len(line)).rstrip() % line + '\n') - if self.args.save_crop: - save_one_box(d.xyxy, - imc, - file=self.save_dir / 'crops' / self.model.names[c] / f'{self.data_path.stem}.jpg', - BGR=True) - - return log_string - def predict(cfg=DEFAULT_CFG, use_python=False): model = cfg.model or 'yolov8n.pt' diff --git a/ultralytics/yolo/v8/pose/predict.py b/ultralytics/yolo/v8/pose/predict.py index 06daa96..bdb5cd2 100644 --- a/ultralytics/yolo/v8/pose/predict.py +++ b/ultralytics/yolo/v8/pose/predict.py @@ -2,7 +2,6 @@ from ultralytics.yolo.engine.results import Results from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops -from ultralytics.yolo.utils.plotting import save_one_box from ultralytics.yolo.v8.detect.predict import DetectionPredictor @@ -34,50 +33,6 @@ class PosePredictor(DetectionPredictor): keypoints=pred_kpts)) return results - def write_results(self, idx, results, batch): - p, im, im0 = batch - log_string = '' - if len(im.shape) == 3: - im = im[None] # expand for batch dim - self.seen += 1 - imc = im0.copy() if self.args.save_crop else im0 - if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 - log_string += f'{idx}: ' - frame = self.dataset.count - else: - frame = getattr(self.dataset, 'frame', 0) - self.data_path = p - self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') - log_string += '%gx%g ' % im.shape[2:] # print string - - result = results[idx] # TODO: make boxes inherit from tensors - if len(result) == 0: - return f'{log_string}(no detections), ' - det = result.boxes - for c in det.cls.unique(): - n = (det.cls == c).sum() # detections per class - log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, " - - if self.args.save or self.args.show: # Add bbox to image - self.plotted_img = result.plot(line_width=self.args.line_thickness, boxes=self.args.boxes) - - # write - for j, d in enumerate(reversed(det)): - c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) - if self.args.save_txt: # Write to file - kpt = (result[j].keypoints[:, :2] / d.orig_shape[[1, 0]]).reshape(-1).tolist() - box = d.xywhn.view(-1).tolist() - line = (c, *box, *kpt) + (conf, ) * self.args.save_conf + (() if id is None else (id, )) - with open(f'{self.txt_path}.txt', 'a') as f: - f.write(('%g ' * len(line)).rstrip() % line + '\n') - if self.args.save_crop: - save_one_box(d.xyxy, - imc, - file=self.save_dir / 'crops' / self.model.model.names[c] / f'{self.data_path.stem}.jpg', - BGR=True) - - return log_string - def predict(cfg=DEFAULT_CFG, use_python=False): model = cfg.model or 'yolov8n-pose.pt' diff --git a/ultralytics/yolo/v8/segment/predict.py b/ultralytics/yolo/v8/segment/predict.py index 969b978..5812487 100644 --- a/ultralytics/yolo/v8/segment/predict.py +++ b/ultralytics/yolo/v8/segment/predict.py @@ -4,7 +4,6 @@ import torch from ultralytics.yolo.engine.results import Results from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops -from ultralytics.yolo.utils.plotting import save_one_box from ultralytics.yolo.v8.detect.predict import DetectionPredictor @@ -40,55 +39,6 @@ class SegmentationPredictor(DetectionPredictor): Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) return results - def write_results(self, idx, results, batch): - p, im, im0 = batch - log_string = '' - if len(im.shape) == 3: - im = im[None] # expand for batch dim - self.seen += 1 - imc = im0.copy() if self.args.save_crop else im0 - if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 - log_string += f'{idx}: ' - frame = self.dataset.count - else: - frame = getattr(self.dataset, 'frame', 0) - - self.data_path = p - self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') - log_string += '%gx%g ' % im.shape[2:] # print string - - result = results[idx] - if len(result) == 0: - return f'{log_string}(no detections), ' - det, mask = result.boxes, result.masks # getting tensors TODO: mask mask,box inherit for tensor - - # Print results - for c in det.cls.unique(): - n = (det.cls == c).sum() # detections per class - log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, " - - # Mask plotting - if self.args.save or self.args.show: - im_gpu = torch.as_tensor(im0, dtype=torch.float16, device=mask.masks.device).permute( - 2, 0, 1).flip(0).contiguous() / 255 if self.args.retina_masks else im[idx] - self.plotted_img = result.plot(line_width=self.args.line_thickness, im_gpu=im_gpu, boxes=self.args.boxes) - - # Write results - for j, d in enumerate(reversed(det)): - c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) - if self.args.save_txt: # Write to file - seg = mask.xyn[len(det) - j - 1].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2) - line = (c, *seg) + (conf, ) * self.args.save_conf + (() if id is None else (id, )) - with open(f'{self.txt_path}.txt', 'a') as f: - f.write(('%g ' * len(line)).rstrip() % line + '\n') - if self.args.save_crop: - save_one_box(d.xyxy, - imc, - file=self.save_dir / 'crops' / self.model.names[c] / f'{self.data_path.stem}.jpg', - BGR=True) - - return log_string - def predict(cfg=DEFAULT_CFG, use_python=False): model = cfg.model or 'yolov8n-seg.pt'