Threadpool fixes and CLI improvements (#550)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher
2023-01-22 17:08:08 +01:00
committed by GitHub
parent d9a0fba251
commit 21b701c4ea
22 changed files with 338 additions and 251 deletions

View File

@ -93,7 +93,7 @@ class BaseDataset(Dataset):
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
assert im_files, f"{self.prefix}No images found"
except Exception as e:
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}: {e}\n{HELP_URL}") from e
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
return im_files
def update_labels(self, include_class: Optional[list]):
@ -134,16 +134,17 @@ class BaseDataset(Dataset):
gb = 0 # Gigabytes of cached images
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni))
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar:
if cache == "disk":
gb += self.npy_files[i].stat().st_size
else: # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
gb += self.ims[i].nbytes
pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
pbar.close()
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(fcn, range(self.ni))
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar:
if cache == "disk":
gb += self.npy_files[i].stat().st_size
else: # 'ram'
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
gb += self.ims[i].nbytes
pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
pbar.close()
def cache_images_to_disk(self, i):
# Saves an image as an *.npy file for faster loading

View File

@ -13,7 +13,7 @@ import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import Pool, ThreadPool
from multiprocessing.pool import ThreadPool
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse
@ -580,7 +580,7 @@ class LoadImagesAndLabels(Dataset):
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
self.im_hw0, self.im_hw = [None] * n, [None] * n
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
with (Pool if n > 10000 else ThreadPool)(NUM_THREADS) as pool:
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(fcn, range(n))
pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar:
@ -1150,7 +1150,7 @@ class HUBDatasetStats():
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
desc = f'{split} images'
total = dataset.n
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
with ThreadPool(NUM_THREADS) as pool:
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc):
pass
print(f'Done. All images saved to {self.im_dir}')

View File

@ -185,9 +185,9 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
return masks, index
def check_dataset_yaml(data, autodownload=True):
def check_dataset_yaml(dataset, autodownload=True):
# Download, check and/or unzip dataset if not found locally
data = check_file(data)
data = check_file(dataset)
# Download (optional)
extract_dir = ''
@ -227,9 +227,11 @@ def check_dataset_yaml(data, autodownload=True):
if val:
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
if not all(x.exists() for x in val):
LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
if not s or not autodownload:
raise FileNotFoundError('Dataset not found ❌')
msg = f"\nDataset '{dataset}' not found ⚠️, missing paths %s" % [str(x) for x in val if not x.exists()]
if s and autodownload:
LOGGER.warning(msg)
else:
raise FileNotFoundError(s)
t = time.time()
if s.startswith('http') and s.endswith('.zip'): # URL
f = Path(s).name # filename