From 522f1937edb0728f1fb1ec6398ae56d7f476c3a4 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 30 Jan 2023 22:34:28 +0100 Subject: [PATCH] ImageNet names, classify inference, resume fixes (#712) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> --- .pre-commit-config.yaml | 10 +- examples/tutorial.ipynb | 117 +- tests/test_cli.py | 2 +- ultralytics/__init__.py | 2 +- ultralytics/nn/autobackend.py | 11 +- ultralytics/yolo/data/build.py | 2 +- .../yolo/data/dataloaders/stream_loaders.py | 12 +- ultralytics/yolo/data/datasets/ImageNet.yaml | 1003 +++++++++++++++++ ultralytics/yolo/engine/model.py | 4 +- ultralytics/yolo/engine/predictor.py | 4 +- ultralytics/yolo/engine/results.py | 2 +- ultralytics/yolo/engine/trainer.py | 21 +- ultralytics/yolo/utils/__init__.py | 6 +- ultralytics/yolo/utils/autobatch.py | 33 +- ultralytics/yolo/v8/detect/predict.py | 3 +- ultralytics/yolo/v8/segment/predict.py | 2 +- 16 files changed, 1120 insertions(+), 114 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 54903db..b0724f1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,11 +31,11 @@ repos: name: Upgrade code args: [ --py37-plus ] - - repo: https://github.com/PyCQA/isort - rev: 5.11.4 - hooks: - - id: isort - name: Sort imports + # - repo: https://github.com/PyCQA/isort + # rev: 5.11.4 + # hooks: + # - id: isort + # name: Sort imports - repo: https://github.com/pre-commit/mirrors-yapf rev: v0.32.0 diff --git a/examples/tutorial.ipynb b/examples/tutorial.ipynb index 057e93e..ed2d7d5 100644 --- a/examples/tutorial.ipynb +++ b/examples/tutorial.ipynb @@ -57,10 +57,9 @@ "metadata": { "id": "wbvMlHd_QwMG", "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 + "base_uri": "https://localhost:8080/" }, - "outputId": "5006941e-44ff-4e27-f53e-31bf87221334" + "outputId": "9bda69d4-e57f-404b-b6fe-117234e24677" }, "source": [ "# Pip install method (recommended)\n", @@ -68,14 +67,14 @@ "import ultralytics\n", "ultralytics.checks()" ], - "execution_count": null, + "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ - "Ultralytics YOLOv8.0.5 🚀 Python-3.8.16 torch-1.13.1+cu116 CUDA:0 (Tesla T4, 15110MiB)\n", - "Setup complete ✅ (2 CPUs, 12.7 GB RAM, 23.0/166.8 GB disk)\n" + "Ultralytics YOLOv8.0.24 🚀 Python-3.8.10 torch-1.13.1+cu116 CUDA:0 (Tesla T4, 15110MiB)\n", + "Setup complete ✅ (2 CPUs, 12.7 GB RAM, 30.8/166.8 GB disk)\n" ] } ] @@ -111,28 +110,27 @@ "colab": { "base_uri": "https://localhost:8080/" }, - "outputId": "3136de6b-2995-4731-e84c-962acb233d89" + "outputId": "abe002b5-3df9-4324-9e50-1587394398a2" }, "source": [ "# Run inference on an image with YOLOv8n\n", - "!yolo task=detect mode=predict model=yolov8n.pt conf=0.25 source='https://ultralytics.com/images/zidane.jpg'" + "!yolo predict model=yolov8n.pt source='https://ultralytics.com/images/zidane.jpg'" ], - "execution_count": null, + "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ - "Downloading https://ultralytics.com/images/zidane.jpg to zidane.jpg...\n", - "100% 165k/165k [00:00<00:00, 12.0MB/s]\n", - "Ultralytics YOLOv8.0.5 🚀 Python-3.8.16 torch-1.13.1+cu116 CUDA:0 (Tesla T4, 15110MiB)\n", "Downloading https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt to yolov8n.pt...\n", - "100% 6.24M/6.24M [00:00<00:00, 58.7MB/s]\n", + "\r 0% 0.00/6.23M [00:00 None: self.boxes = Boxes(boxes, orig_shape) if boxes is not None else None # native size boxes self.masks = Masks(masks, orig_shape) if masks is not None else None # native size or imgsz masks - self.probs = probs.softmax(0) if probs is not None else None + self.probs = probs if probs is not None else None self.orig_shape = orig_shape self.comp = ["boxes", "masks", "probs"] diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index f8c46fa..0dd49e0 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -21,7 +21,7 @@ from torch.optim import lr_scheduler from tqdm import tqdm from ultralytics import __version__ -from ultralytics.nn.tasks import attempt_load_one_weight +from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis, @@ -515,14 +515,15 @@ class BaseTrainer: def check_resume(self): resume = self.args.resume if resume: - last = Path(check_file(resume) if isinstance(resume, (str, Path)) else get_latest_run()) - args_yaml = last.parent.parent / 'args.yaml' # train options yaml - assert args_yaml.is_file(), \ - FileNotFoundError(f'Resume checkpoint {last} not found. ' - 'Please pass a valid checkpoint to resume from, i.e. yolo resume=path/to/last.pt') - args = get_cfg(args_yaml) # replace - args.model, resume = str(last), True # reinstate - self.args = args + try: + last = Path( + check_file(resume) if isinstance(resume, (str, + Path)) and Path(resume).exists() else get_latest_run()) + self.args = get_cfg(attempt_load_weights(last).args) + self.args.model, resume = str(last), True # reinstate + except Exception as e: + raise FileNotFoundError("Resume checkpoint not found. Please pass a valid checkpoint to resume from, " + "i.e. 'yolo train resume model=path/to/last.pt'") from e self.resume = resume def resume_training(self, ckpt): @@ -541,7 +542,7 @@ class BaseTrainer: f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \ f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'" LOGGER.info( - f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs') + f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs') if self.epochs < start_epoch: LOGGER.info( f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.") diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index 66c85ff..5023a18 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -479,7 +479,7 @@ def set_sentry(): if SETTINGS['sync'] and \ not is_pytest_running() and \ not is_github_actions_ci() and \ - (is_pip_package() or + ((is_pip_package() and not is_git_dir()) or (get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")): import sentry_sdk # noqa @@ -493,6 +493,10 @@ def set_sentry(): before_send=before_send, ignore_errors=[KeyboardInterrupt]) + # Disable all sentry logging + for logger in "sentry_sdk", "sentry_sdk.errors": + logging.getLogger(logger).setLevel(logging.CRITICAL) + def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'): """ diff --git a/ultralytics/yolo/utils/autobatch.py b/ultralytics/yolo/utils/autobatch.py index 7aaf146..d528864 100644 --- a/ultralytics/yolo/utils/autobatch.py +++ b/ultralytics/yolo/utils/autobatch.py @@ -52,21 +52,22 @@ def autobatch(model, imgsz=640, fraction=0.7, batch_size=16): try: img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes] results = profile(img, model, n=3, device=device) - except Exception as e: - LOGGER.warning(f'{prefix}{e}') - # Fit a solution - y = [x[2] for x in results if x] # memory [2] - p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit - b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size) - if None in results: # some sizes failed - i = results.index(None) # first fail index - if b >= batch_sizes[i]: # y intercept above failure point - b = batch_sizes[max(i - 1, 0)] # select prior safe point - if b < 1 or b > 1024: # b outside of safe range - b = batch_size - LOGGER.warning(f'{prefix}WARNING ⚠️ CUDA anomaly detected, recommend restart environment and retry command.') + # Fit a solution + y = [x[2] for x in results if x] # memory [2] + p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit + b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size) + if None in results: # some sizes failed + i = results.index(None) # first fail index + if b >= batch_sizes[i]: # y intercept above failure point + b = batch_sizes[max(i - 1, 0)] # select prior safe point + if b < 1 or b > 1024: # b outside of safe range + b = batch_size + LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.') - fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted - LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅') - return b + fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted + LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅') + return b + except Exception as e: + LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.') + return batch_size diff --git a/ultralytics/yolo/v8/detect/predict.py b/ultralytics/yolo/v8/detect/predict.py index 32f2cc4..9f5602a 100644 --- a/ultralytics/yolo/v8/detect/predict.py +++ b/ultralytics/yolo/v8/detect/predict.py @@ -41,7 +41,7 @@ class DetectionPredictor(BasePredictor): if len(im.shape) == 3: im = im[None] # expand for batch dim self.seen += 1 - im0 = im0.copy() + imc = im0.copy() if self.args.save_crop else im0 if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 log_string += f'{idx}: ' frame = self.dataset.count @@ -73,7 +73,6 @@ class DetectionPredictor(BasePredictor): self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}') self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.save_crop: - imc = im0.copy() save_one_box(d.xyxy, imc, file=self.save_dir / 'crops' / self.model.model.names[c] / f'{self.data_path.stem}.jpg', diff --git a/ultralytics/yolo/v8/segment/predict.py b/ultralytics/yolo/v8/segment/predict.py index 518fa9e..b45a098 100644 --- a/ultralytics/yolo/v8/segment/predict.py +++ b/ultralytics/yolo/v8/segment/predict.py @@ -43,6 +43,7 @@ class SegmentationPredictor(DetectionPredictor): if len(im.shape) == 3: im = im[None] # expand for batch dim self.seen += 1 + imc = im0.copy() if self.args.save_crop else im0 if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 log_string += f'{idx}: ' frame = self.dataset.count @@ -91,7 +92,6 @@ class SegmentationPredictor(DetectionPredictor): self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}') self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.boxes else None if self.args.save_crop: - imc = im0.copy() save_one_box(d.xyxy, imc, file=self.save_dir / 'crops' / self.model.model.names[c] / f'{self.data_path.stem}.jpg',