new check_dataset functions (#43)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2022-11-12 19:32:41 +05:30
committed by GitHub
parent d143ac666f
commit 1f3aad86c1
13 changed files with 336 additions and 62 deletions

View File

@ -17,26 +17,6 @@ from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zer
# BaseTrainer python usage
class ClassificationTrainer(BaseTrainer):
def get_dataset(self, dataset):
    """Locate a classification dataset, downloading it first if it is missing.

    The dataset is looked up at ``datasets/<dataset>``, first as given and
    then relative to the current working directory. If neither is a
    directory, a download is attempted: the ImageNet helper script for
    ``'imagenet'``, otherwise ``<dataset>.zip`` from the yolov5 v1.0 release
    assets.

    Args:
        dataset: Dataset name, used as a sub-directory of ``datasets`` and
            as the stem of the release zip.

    Returns:
        Tuple ``(train_set, test_set)`` of paths: ``<data_dir>/train`` and
        ``<data_dir>/test`` if that exists, else ``<data_dir>/val``.

    Raises:
        subprocess.CalledProcessError: If the ImageNet download script exits
            with a non-zero status (``check=True``).
    """
    # temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
    data = Path("datasets") / dataset
    # NOTE(review): torch_distributed_zero_first presumably lets one rank do the
    # download while the others wait, and WorkingDirectory presumably restores
    # the cwd on exit - confirm against their definitions.
    with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()):
        data_dir = data if data.is_dir() else (Path.cwd() / data)
        if not data_dir.is_dir():
            self.console.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
            t = time.time()
            # Compare the dataset *name*: `data` always carries the "datasets/"
            # prefix, so str(data) could never equal 'imagenet' and this branch
            # was unreachable.
            if str(dataset) == 'imagenet':
                # Argument list instead of a shell string, so an install path
                # containing spaces or shell metacharacters is passed intact.
                subprocess.run(["bash", str(v8.ROOT / 'data/scripts/get_imagenet.sh')], check=True)
            else:
                url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
                download(url, dir=data_dir.parent)
            s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
            self.console.info(s)
        train_set = data_dir / "train"
        test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val'  # data/test or data/val
    return train_set, test_set
def get_dataloader(self, dataset_path, batch_size=None, rank=0):
return build_classification_dataloader(path=dataset_path,
imgsz=self.args.img_size,