Threadpool fixes and CLI improvements (#550)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -93,7 +93,7 @@ class BaseDataset(Dataset):
|
||||
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
||||
assert im_files, f"{self.prefix}No images found"
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}: {e}\n{HELP_URL}") from e
|
||||
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
|
||||
return im_files
|
||||
|
||||
def update_labels(self, include_class: Optional[list]):
|
||||
@ -134,16 +134,17 @@ class BaseDataset(Dataset):
|
||||
gb = 0 # Gigabytes of cached images
|
||||
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
|
||||
fcn = self.cache_images_to_disk if cache == "disk" else self.load_image
|
||||
results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni))
|
||||
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if cache == "disk":
|
||||
gb += self.npy_files[i].stat().st_size
|
||||
else: # 'ram'
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||
gb += self.ims[i].nbytes
|
||||
pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
|
||||
pbar.close()
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(fcn, range(self.ni))
|
||||
pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
if cache == "disk":
|
||||
gb += self.npy_files[i].stat().st_size
|
||||
else: # 'ram'
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||
gb += self.ims[i].nbytes
|
||||
pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})"
|
||||
pbar.close()
|
||||
|
||||
def cache_images_to_disk(self, i):
|
||||
# Saves an image as an *.npy file for faster loading
|
||||
|
@ -13,7 +13,7 @@ import random
|
||||
import shutil
|
||||
import time
|
||||
from itertools import repeat
|
||||
from multiprocessing.pool import Pool, ThreadPool
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from urllib.parse import urlparse
|
||||
@ -580,7 +580,7 @@ class LoadImagesAndLabels(Dataset):
|
||||
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
||||
self.im_hw0, self.im_hw = [None] * n, [None] * n
|
||||
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
|
||||
with (Pool if n > 10000 else ThreadPool)(NUM_THREADS) as pool:
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
results = pool.imap(fcn, range(n))
|
||||
pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||
for i, x in pbar:
|
||||
@ -1150,7 +1150,7 @@ class HUBDatasetStats():
|
||||
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
|
||||
desc = f'{split} images'
|
||||
total = dataset.n
|
||||
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
|
||||
with ThreadPool(NUM_THREADS) as pool:
|
||||
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc):
|
||||
pass
|
||||
print(f'Done. All images saved to {self.im_dir}')
|
||||
|
@ -185,9 +185,9 @@ def polygons2masks_overlap(imgsz, segments, downsample_ratio=1):
|
||||
return masks, index
|
||||
|
||||
|
||||
def check_dataset_yaml(data, autodownload=True):
|
||||
def check_dataset_yaml(dataset, autodownload=True):
|
||||
# Download, check and/or unzip dataset if not found locally
|
||||
data = check_file(data)
|
||||
data = check_file(dataset)
|
||||
|
||||
# Download (optional)
|
||||
extract_dir = ''
|
||||
@ -227,9 +227,11 @@ def check_dataset_yaml(data, autodownload=True):
|
||||
if val:
|
||||
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
|
||||
if not all(x.exists() for x in val):
|
||||
LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
|
||||
if not s or not autodownload:
|
||||
raise FileNotFoundError('Dataset not found ❌')
|
||||
msg = f"\nDataset '{dataset}' not found ⚠️, missing paths %s" % [str(x) for x in val if not x.exists()]
|
||||
if s and autodownload:
|
||||
LOGGER.warning(msg)
|
||||
else:
|
||||
raise FileNotFoundError(s)
|
||||
t = time.time()
|
||||
if s.startswith('http') and s.endswith('.zip'): # URL
|
||||
f = Path(s).name # filename
|
||||
|
Reference in New Issue
Block a user