new check_dataset functions (#43)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -17,26 +17,6 @@ from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zer
|
||||
# BaseTrainer python usage
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def get_dataset(self, dataset):
|
||||
# temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
|
||||
data = Path("datasets") / dataset
|
||||
with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()):
|
||||
data_dir = data if data.is_dir() else (Path.cwd() / data)
|
||||
if not data_dir.is_dir():
|
||||
self.console.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
|
||||
t = time.time()
|
||||
if str(data) == 'imagenet':
|
||||
subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
|
||||
else:
|
||||
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
|
||||
download(url, dir=data_dir.parent)
|
||||
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
|
||||
self.console.info(s)
|
||||
train_set = data_dir / "train"
|
||||
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
|
||||
|
||||
return train_set, test_set
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=None, rank=0):
|
||||
return build_classification_dataloader(path=dataset_path,
|
||||
imgsz=self.args.img_size,
|
||||
|
@ -21,26 +21,6 @@ from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, de_parallel, torch_di
|
||||
# BaseTrainer python usage
|
||||
class SegmentationTrainer(BaseTrainer):
|
||||
|
||||
def get_dataset(self, dataset):
|
||||
# temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
|
||||
data = Path("datasets") / dataset
|
||||
with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()):
|
||||
data_dir = data if data.is_dir() else (Path.cwd() / data)
|
||||
if not data_dir.is_dir():
|
||||
self.console.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
|
||||
t = time.time()
|
||||
if str(data) == 'imagenet':
|
||||
subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
|
||||
else:
|
||||
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
|
||||
download(url, dir=data_dir.parent)
|
||||
# TODO: add colorstr
|
||||
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n"
|
||||
self.console.info(s)
|
||||
train_set = data_dir.parent / "coco128-seg"
|
||||
test_set = train_set
|
||||
return train_set, test_set
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size, rank=0):
|
||||
# TODO: manage splits differently
|
||||
# calculate stride - check if model is initialized
|
||||
@ -253,7 +233,7 @@ class SegmentationTrainer(BaseTrainer):
|
||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
cfg.cfg = v8.ROOT / "models/yolov5n-seg.yaml"
|
||||
cfg.data = cfg.data or "coco128-segments" # or yolo.ClassificationDataset("mnist")
|
||||
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
|
||||
trainer = SegmentationTrainer(cfg)
|
||||
trainer.train()
|
||||
|
||||
|
@ -8,9 +8,9 @@ import torch.nn.functional as F
|
||||
from ultralytics.yolo.engine.validator import BaseValidator
|
||||
from ultralytics.yolo.utils import ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.yolo.utils.files import yaml_load
|
||||
from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou,
|
||||
fitness_segmentation, mask_iou)
|
||||
from ultralytics.yolo.utils.modeling import yaml_load
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user