ultralytics 8.0.83 Neptune AI logging addition (#2130)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Snyk bot <snyk-bot@snyk.io>
Co-authored-by: Toutatis64 <Toutatis64@users.noreply.github.com>
Co-authored-by: M. Tolga Cangöz <46008593+standardAI@users.noreply.github.com>
Co-authored-by: Talia Bender <85292283+taliabender@users.noreply.github.com>
Co-authored-by: Ophélie Le Mentec <17216799+ouphi@users.noreply.github.com>
Co-authored-by: Kadir Şahin <68073829+ssahinnkadir@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
This commit is contained in:
Glenn Jocher
2023-04-19 21:25:47 +02:00
committed by GitHub
parent 55a03ad85f
commit 6c082ebd6f
13 changed files with 178 additions and 43 deletions

View File

@ -3,6 +3,7 @@
import glob
import math
import os
import random
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from pathlib import Path
@ -10,10 +11,11 @@ from typing import Optional
import cv2
import numpy as np
import psutil
from torch.utils.data import Dataset
from tqdm import tqdm
from ..utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT
from ..utils import LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT
from .utils import HELP_URL, IMG_FORMATS
@ -63,14 +65,10 @@ class BaseDataset(Dataset):
self.augment = augment
self.single_cls = single_cls
self.prefix = prefix
self.im_files = self.get_img_files(self.img_path)
self.labels = self.get_labels()
self.update_labels(include_class=classes) # single_cls and include_class
self.ni = len(self.labels)
# Rect stuff
self.ni = len(self.labels) # number of images
self.rect = rect
self.batch_size = batch_size
self.stride = stride
@ -80,6 +78,8 @@ class BaseDataset(Dataset):
self.set_rectangle()
# Cache stuff
if cache == 'ram' and not self.check_cache_ram():
cache = False
self.ims = [None] * self.ni
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
if cache:
@ -148,7 +148,7 @@ class BaseDataset(Dataset):
def cache_images(self, cache):
"""Cache images to memory or disk."""
gb = 0 # Gigabytes of cached images
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
with ThreadPool(NUM_THREADS) as pool:
@ -156,11 +156,11 @@ class BaseDataset(Dataset):
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar:
if cache == 'disk':
gb += self.npy_files[i].stat().st_size
b += self.npy_files[i].stat().st_size
else: # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
gb += self.ims[i].nbytes
pbar.desc = f'{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})'
b += self.ims[i].nbytes
pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})'
pbar.close()
def cache_images_to_disk(self, i):
@ -169,6 +169,24 @@ class BaseDataset(Dataset):
if not f.exists():
np.save(f.as_posix(), cv2.imread(self.im_files[i]))
def check_cache_ram(self, safety_margin=0.5):
"""Check image caching requirements vs available memory."""
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
n = min(self.ni, 30) # extrapolate from 30 random images
for _ in range(n):
im = cv2.imread(random.choice(self.im_files)) # sample image
ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
b += im.nbytes * ratio ** 2
mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM
mem = psutil.virtual_memory()
cache = mem_required < mem.available # to cache or not to cache, that is the question
if not cache:
LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
f'with {int(safety_margin * 100)}% safety margin but only '
f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
return cache
def set_rectangle(self):
"""Sets the shape of bounding boxes for YOLO detections as rectangles."""
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index