Threadpool fixes and CLI improvements (#550)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher
2023-01-22 17:08:08 +01:00
committed by GitHub
parent d9a0fba251
commit 21b701c4ea
22 changed files with 338 additions and 251 deletions

View File

@ -13,7 +13,7 @@ import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import Pool, ThreadPool
from multiprocessing.pool import ThreadPool
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse
@ -580,7 +580,7 @@ class LoadImagesAndLabels(Dataset):
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
self.im_hw0, self.im_hw = [None] * n, [None] * n
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
with (Pool if n > 10000 else ThreadPool)(NUM_THREADS) as pool:
with ThreadPool(NUM_THREADS) as pool:
results = pool.imap(fcn, range(n))
pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
for i, x in pbar:
@ -1150,7 +1150,7 @@ class HUBDatasetStats():
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
desc = f'{split} images'
total = dataset.n
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
with ThreadPool(NUM_THREADS) as pool:
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc):
pass
print(f'Done. All images saved to {self.im_dir}')