diff --git a/tests/test_cli.py b/tests/test_cli.py index 21db7c6..24e9703 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -29,7 +29,7 @@ def test_special_modes(): @pytest.mark.parametrize('task,model,data', TASK_ARGS) def test_train(task, model, data): - run(f'yolo train {task} model={model}.yaml data={data} imgsz=32 epochs=1') + run(f'yolo train {task} model={model}.yaml data={data} imgsz=32 epochs=1 cache=disk') @pytest.mark.parametrize('task,model,data', TASK_ARGS) diff --git a/tests/test_python.py b/tests/test_python.py index 0e8d9f1..4b64eea 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -108,13 +108,13 @@ def test_amp(): def test_train_scratch(): model = YOLO(CFG) - model.train(data='coco8.yaml', epochs=1, imgsz=32) + model.train(data='coco8.yaml', epochs=1, imgsz=32, cache='disk') # test disk caching model(SOURCE) def test_train_pretrained(): model = YOLO(MODEL) - model.train(data='coco8.yaml', epochs=1, imgsz=32) + model.train(data='coco8.yaml', epochs=1, imgsz=32, cache='ram') # test RAM caching model(SOURCE) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 8939c0e..5b3beda 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.108' +__version__ = '8.0.109' from ultralytics.hub import start from ultralytics.vit.rtdetr import RTDETR diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py index d2ea0cc..cbb4843 100644 --- a/ultralytics/yolo/data/base.py +++ b/ultralytics/yolo/data/base.py @@ -54,7 +54,7 @@ class BaseDataset(Dataset): hyp=DEFAULT_CFG, prefix='', rect=False, - batch_size=None, + batch_size=16, stride=32, pad=0.5, single_cls=False, @@ -77,6 +77,10 @@ class BaseDataset(Dataset): assert self.batch_size is not None self.set_rectangle() + # Buffer thread for mosaic images + self.buffer = [] # buffer size = batch size + self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0 + # Cache stuff if cache == 'ram' and not self.check_cache_ram(): cache = False @@ -88,10 +92,6 @@ class BaseDataset(Dataset): # Transforms self.transforms = self.build_transforms(hyp=hyp) - # Buffer thread for mosaic images - self.buffer = [] # buffer size = batch size - self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0 - def get_img_files(self, img_path): """Read image files.""" try: