diff --git a/docker/Dockerfile-arm64 b/docker/Dockerfile-arm64 index 6c18d78..c11dba4 100644 --- a/docker/Dockerfile-arm64 +++ b/docker/Dockerfile-arm64 @@ -3,7 +3,7 @@ # Image is aarch64-compatible for Apple M1 and other ARM architectures i.e. Jetson Nano and Raspberry Pi # Start FROM Ubuntu image https://hub.docker.com/_/ubuntu -FROM arm64v8/ubuntu:rolling +FROM arm64v8/ubuntu:22.10 # Downloads to user config dir ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/ diff --git a/docker/Dockerfile-cpu b/docker/Dockerfile-cpu index 11b3350..fa6ec12 100644 --- a/docker/Dockerfile-cpu +++ b/docker/Dockerfile-cpu @@ -3,7 +3,7 @@ # Image is CPU-optimized for ONNX, OpenVINO and PyTorch YOLOv8 deployments # Start FROM Ubuntu image https://hub.docker.com/_/ubuntu -FROM ubuntu:rolling +FROM ubuntu:22.10 # Downloads to user config dir ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/ diff --git a/docs/modes/predict.md b/docs/modes/predict.md index a5189d4..b252d43 100644 --- a/docs/modes/predict.md +++ b/docs/modes/predict.md @@ -56,7 +56,7 @@ whether each source can be used in streaming mode with `stream=True` ✅ and an ## Arguments -`model.predict` accepts multiple arguments that control the predction operation. These arguments can be passed directly to `model.predict`: +`model.predict` accepts multiple arguments that control the prediction operation. These arguments can be passed directly to `model.predict`: !!! example ``` model.predict(source, save=True, imgsz=320, conf=0.5) @@ -273,4 +273,4 @@ Here's a Python script using OpenCV (cv2) and YOLOv8 to run inference on video f # Release the video capture object and close the display window cap.release() cv2.destroyAllWindows() - ``` \ No newline at end of file + ``` diff --git a/docs/reference/hub/utils.md b/docs/reference/hub/utils.md index 5f1b00c..2931d9c 100644 --- a/docs/reference/hub/utils.md +++ b/docs/reference/hub/utils.md @@ -3,11 +3,6 @@ :::ultralytics.hub.utils.Traces

-# check_dataset_disk_space ---- -:::ultralytics.hub.utils.check_dataset_disk_space -

- # request_with_credentials --- :::ultralytics.hub.utils.request_with_credentials diff --git a/docs/reference/yolo/utils/downloads.md b/docs/reference/yolo/utils/downloads.md index 103c261..58b8d9a 100644 --- a/docs/reference/yolo/utils/downloads.md +++ b/docs/reference/yolo/utils/downloads.md @@ -8,6 +8,11 @@ :::ultralytics.yolo.utils.downloads.unzip_file

+# check_disk_space +--- +:::ultralytics.yolo.utils.downloads.check_disk_space +

+ # safe_download --- :::ultralytics.yolo.utils.downloads.safe_download diff --git a/setup.py b/setup.py index e40392a..04ce73b 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ setup( 'Intended Audience :: Developers', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: GNU Affero General Public License v3 (AGPLv3)', + 'License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 409a89a..5a2a3f7 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.82' +__version__ = '8.0.83' from ultralytics.hub import start from ultralytics.yolo.engine.model import YOLO diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py index 08265a6..3ed64d5 100644 --- a/ultralytics/yolo/data/base.py +++ b/ultralytics/yolo/data/base.py @@ -3,6 +3,7 @@ import glob import math import os +import random from copy import deepcopy from multiprocessing.pool import ThreadPool from pathlib import Path @@ -10,10 +11,11 @@ from typing import Optional import cv2 import numpy as np +import psutil from torch.utils.data import Dataset from tqdm import tqdm -from ..utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT +from ..utils import LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT from .utils import HELP_URL, IMG_FORMATS @@ -63,14 +65,10 @@ class BaseDataset(Dataset): self.augment = augment self.single_cls = single_cls self.prefix = prefix - self.im_files = self.get_img_files(self.img_path) self.labels = self.get_labels() self.update_labels(include_class=classes) # single_cls and include_class - - self.ni = len(self.labels) - - # Rect stuff + self.ni = len(self.labels) # number of images self.rect = rect self.batch_size = batch_size self.stride = stride @@ -80,6 +78,8 @@ class BaseDataset(Dataset): self.set_rectangle() # Cache stuff + if cache == 'ram' and not self.check_cache_ram(): + cache = False self.ims = [None] * self.ni self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files] if cache: @@ -148,7 +148,7 @@ class BaseDataset(Dataset): def cache_images(self, cache): """Cache images to memory or disk.""" - gb = 0 # Gigabytes of cached images + b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image with ThreadPool(NUM_THREADS) as pool: @@ -156,11 +156,11 @@ class BaseDataset(Dataset): pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0) for i, x in pbar: if cache == 'disk': - gb += self.npy_files[i].stat().st_size + b += self.npy_files[i].stat().st_size else: # 'ram' self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i) - gb += self.ims[i].nbytes - pbar.desc = f'{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})' + b += self.ims[i].nbytes + pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})' pbar.close() def cache_images_to_disk(self, i): @@ -169,6 +169,24 @@ class BaseDataset(Dataset): if not f.exists(): np.save(f.as_posix(), cv2.imread(self.im_files[i])) + def check_cache_ram(self, safety_margin=0.5): + """Check image caching requirements vs available memory.""" + b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes + n = min(self.ni, 30) # extrapolate from 30 random images + for _ in range(n): + im = cv2.imread(random.choice(self.im_files)) # sample image + ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio + b += im.nbytes * ratio ** 2 + mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM + mem = psutil.virtual_memory() + cache = mem_required < mem.available # to cache or not to cache, that is the question + if not cache: + LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images ' + f'with {int(safety_margin * 100)}% safety margin but only ' + f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, ' + f"{'caching images ✅' if cache else 'not caching images ⚠️'}") + return cache + def set_rectangle(self): """Sets the shape of bounding boxes for YOLO detections as rectangles.""" bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 6e49716..efa938d 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -469,31 +469,27 @@ class YOLO: @property def names(self): - """ - Returns class names of the loaded model. - """ + """Returns class names of the loaded model.""" return self.model.names if hasattr(self.model, 'names') else None @property def device(self): - """ - Returns device if PyTorch model - """ + """Returns device if PyTorch model.""" return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None @property def transforms(self): - """ - Returns transform of the loaded model. - """ + """Returns transform of the loaded model.""" return self.model.transforms if hasattr(self.model, 'transforms') else None def add_callback(self, event: str, func): - """ - Add callback - """ + """Add a callback.""" self.callbacks[event].append(func) + def clear_callback(self, event: str): + """Clear all event callbacks.""" + self.callbacks[event] = [] + @staticmethod def _reset_ckpt_args(args): """Reset arguments when loading a PyTorch model.""" diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index 94563f8..92f4646 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -734,3 +734,26 @@ ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' 'Docker' if is_docker() else platform.system() TESTS_RUNNING = is_pytest_running() or is_github_actions_ci() set_sentry() + +# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------------ +imshow_ = cv2.imshow # copy to avoid recursion errors + + +def imread(filename, flags=cv2.IMREAD_COLOR): + return cv2.imdecode(np.fromfile(filename, np.uint8), flags) + + +def imwrite(filename, img): + try: + cv2.imencode(Path(filename).suffix, img)[1].tofile(filename) + return True + except Exception: + return False + + +def imshow(path, im): + imshow_(path.encode('unicode_escape').decode(), im) + + +if Path(inspect.stack()[0].filename).parent.parent.as_posix() in inspect.stack()[-1].filename: + cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine diff --git a/ultralytics/yolo/utils/callbacks/base.py b/ultralytics/yolo/utils/callbacks/base.py index 4cf06c7..94da2b2 100644 --- a/ultralytics/yolo/utils/callbacks/base.py +++ b/ultralytics/yolo/utils/callbacks/base.py @@ -200,11 +200,12 @@ def add_integration_callbacks(instance): from .comet import callbacks as comet_callbacks from .hub import callbacks as hub_callbacks from .mlflow import callbacks as mf_callbacks + from .neptune import callbacks as neptune_callbacks from .raytune import callbacks as tune_callbacks from .tensorboard import callbacks as tb_callbacks from .wb import callbacks as wb_callbacks - for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks, mf_callbacks, tune_callbacks, wb_callbacks: + for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks, mf_callbacks, tune_callbacks, wb_callbacks, neptune_callbacks: for k, v in x.items(): if v not in instance.callbacks[k]: # prevent duplicate callbacks addition instance.callbacks[k].append(v) # callback[name].append(func) diff --git a/ultralytics/yolo/utils/callbacks/mlflow.py b/ultralytics/yolo/utils/callbacks/mlflow.py index 36d092d..1c2ed74 100644 --- a/ultralytics/yolo/utils/callbacks/mlflow.py +++ b/ultralytics/yolo/utils/callbacks/mlflow.py @@ -52,19 +52,12 @@ def on_fit_epoch_end(trainer): run.log_metrics(metrics=metrics_dict, step=trainer.epoch) -def on_model_save(trainer): - """Logs model and metrics to mlflow on save.""" - if mlflow: - run.log_artifact(trainer.last) - - def on_train_end(trainer): """Called at end of train loop to log model artifact info.""" if mlflow: root_dir = Path(__file__).resolve().parents[3] + run.log_artifact(trainer.last) run.log_artifact(trainer.best) - model_uri = f'runs:/{run_id}/' - run.register_model(model_uri, experiment_name) run.pyfunc.log_model(artifact_path=experiment_name, code_path=[str(root_dir)], artifacts={'model_path': str(trainer.save_dir)}, @@ -74,5 +67,4 @@ def on_train_end(trainer): callbacks = { 'on_pretrain_routine_end': on_pretrain_routine_end, 'on_fit_epoch_end': on_fit_epoch_end, - 'on_model_save': on_model_save, 'on_train_end': on_train_end} if mlflow else {} diff --git a/ultralytics/yolo/utils/callbacks/neptune.py b/ultralytics/yolo/utils/callbacks/neptune.py new file mode 100644 index 0000000..ca72b6b --- /dev/null +++ b/ultralytics/yolo/utils/callbacks/neptune.py @@ -0,0 +1,105 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +import matplotlib.image as mpimg +import matplotlib.pyplot as plt + +from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING +from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params + +try: + import neptune + from neptune.types import File + + assert not TESTS_RUNNING # do not log pytest + assert hasattr(neptune, '__version__') +except (ImportError, AssertionError): + neptune = None + +run = None # NeptuneAI experiment logger instance + + +def _log_scalars(scalars, step=0): + """Log scalars to the NeptuneAI experiment logger.""" + if run: + for k, v in scalars.items(): + run[k].append(value=v, step=step) + + +def _log_images(imgs_dict, group=''): + """Log scalars to the NeptuneAI experiment logger.""" + if run: + for k, v in imgs_dict.items(): + run[f'{group}/{k}'].upload(File(v)) + + +def _log_plot(title, plot_path): + """Log plots to the NeptuneAI experiment logger.""" + """ + Log image as plot in the plot section of NeptuneAI + + arguments: + title (str) Title of the plot + plot_path (PosixPath or str) Path to the saved image file + """ + img = mpimg.imread(plot_path) + fig = plt.figure() + ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks + ax.imshow(img) + run[f'Plots/{title}'].upload(fig) + + +def on_pretrain_routine_start(trainer): + """Callback function called before the training routine starts.""" + try: + global run + run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8']) + run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()} + except Exception as e: + LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}') + + +def on_train_epoch_end(trainer): + """Callback function called at end of each training epoch.""" + _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1) + _log_scalars(trainer.lr, trainer.epoch + 1) + if trainer.epoch == 1: + _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic') + + +def on_fit_epoch_end(trainer): + """Callback function called at end of each fit (train+val) epoch.""" + if run and trainer.epoch == 0: + model_info = { + 'parameters': get_num_params(trainer.model), + 'GFLOPs': round(get_flops(trainer.model), 3), + 'speed(ms)': round(trainer.validator.speed['inference'], 3)} + run['Configuration/Model'] = model_info + _log_scalars(trainer.metrics, trainer.epoch + 1) + + +def on_val_end(validator): + """Callback function called at end of each validation.""" + if run: + # Log val_labels and val_pred + _log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation') + + +def on_train_end(trainer): + """Callback function called at end of training.""" + if run: + # Log final results, CM matrix + PR plots + files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))] + files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter + for f in files: + _log_plot(title=f.stem, plot_path=f) + # Log the final model + run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str( + trainer.best))) + run.stop() + + +callbacks = { + 'on_pretrain_routine_start': on_pretrain_routine_start, + 'on_train_epoch_end': on_train_epoch_end, + 'on_fit_epoch_end': on_fit_epoch_end, + 'on_val_end': on_val_end, + 'on_train_end': on_train_end} if neptune else {}