new check_dataset functions (#43)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		
							
								
								
									
										2
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							| @ -94,7 +94,7 @@ jobs: | ||||
|       - name: Test segmentation | ||||
|         shell: bash  # for Windows compatibility | ||||
|         run: | | ||||
|           python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=1 img_size=64 | ||||
|           python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64 | ||||
|       - name: Test classification | ||||
|         shell: bash  # for Windows compatibility | ||||
|         run: | | ||||
|  | ||||
							
								
								
									
										101
									
								
								ultralytics/yolo/data/datasets/coco128-seg.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								ultralytics/yolo/data/datasets/coco128-seg.yaml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,101 @@ | ||||
| # YOLOv5 🚀 by Ultralytics, GPL-3.0 license | ||||
| # COCO128-seg dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics | ||||
| # Example usage: python train.py --data coco128.yaml | ||||
| # parent | ||||
| # ├── yolov5 | ||||
| # └── datasets | ||||
| #     └── coco128-seg  ← downloads here (7 MB) | ||||
|  | ||||
|  | ||||
| # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] | ||||
| path: ../datasets/coco128-seg  # dataset root dir | ||||
| train: images/train2017  # train images (relative to 'path') 128 images | ||||
| val: images/train2017  # val images (relative to 'path') 128 images | ||||
| test:  # test images (optional) | ||||
|  | ||||
| # Classes | ||||
| names: | ||||
|   0: person | ||||
|   1: bicycle | ||||
|   2: car | ||||
|   3: motorcycle | ||||
|   4: airplane | ||||
|   5: bus | ||||
|   6: train | ||||
|   7: truck | ||||
|   8: boat | ||||
|   9: traffic light | ||||
|   10: fire hydrant | ||||
|   11: stop sign | ||||
|   12: parking meter | ||||
|   13: bench | ||||
|   14: bird | ||||
|   15: cat | ||||
|   16: dog | ||||
|   17: horse | ||||
|   18: sheep | ||||
|   19: cow | ||||
|   20: elephant | ||||
|   21: bear | ||||
|   22: zebra | ||||
|   23: giraffe | ||||
|   24: backpack | ||||
|   25: umbrella | ||||
|   26: handbag | ||||
|   27: tie | ||||
|   28: suitcase | ||||
|   29: frisbee | ||||
|   30: skis | ||||
|   31: snowboard | ||||
|   32: sports ball | ||||
|   33: kite | ||||
|   34: baseball bat | ||||
|   35: baseball glove | ||||
|   36: skateboard | ||||
|   37: surfboard | ||||
|   38: tennis racket | ||||
|   39: bottle | ||||
|   40: wine glass | ||||
|   41: cup | ||||
|   42: fork | ||||
|   43: knife | ||||
|   44: spoon | ||||
|   45: bowl | ||||
|   46: banana | ||||
|   47: apple | ||||
|   48: sandwich | ||||
|   49: orange | ||||
|   50: broccoli | ||||
|   51: carrot | ||||
|   52: hot dog | ||||
|   53: pizza | ||||
|   54: donut | ||||
|   55: cake | ||||
|   56: chair | ||||
|   57: couch | ||||
|   58: potted plant | ||||
|   59: bed | ||||
|   60: dining table | ||||
|   61: toilet | ||||
|   62: tv | ||||
|   63: laptop | ||||
|   64: mouse | ||||
|   65: remote | ||||
|   66: keyboard | ||||
|   67: cell phone | ||||
|   68: microwave | ||||
|   69: oven | ||||
|   70: toaster | ||||
|   71: sink | ||||
|   72: refrigerator | ||||
|   73: book | ||||
|   74: clock | ||||
|   75: vase | ||||
|   76: scissors | ||||
|   77: teddy bear | ||||
|   78: hair drier | ||||
|   79: toothbrush | ||||
|  | ||||
|  | ||||
| # Download script/URL (optional) | ||||
| download: https://ultralytics.com/assets/coco128-seg.zip | ||||
							
								
								
									
										101
									
								
								ultralytics/yolo/data/datasets/coco128.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								ultralytics/yolo/data/datasets/coco128.yaml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,101 @@ | ||||
| # YOLOv5 🚀 by Ultralytics, GPL-3.0 license | ||||
| # COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics | ||||
| # Example usage: python train.py --data coco128.yaml | ||||
| # parent | ||||
| # ├── yolov5 | ||||
| # └── datasets | ||||
| #     └── coco128  ← downloads here (7 MB) | ||||
|  | ||||
|  | ||||
| # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] | ||||
| path: ../datasets/coco128  # dataset root dir | ||||
| train: images/train2017  # train images (relative to 'path') 128 images | ||||
| val: images/train2017  # val images (relative to 'path') 128 images | ||||
| test:  # test images (optional) | ||||
|  | ||||
| # Classes | ||||
| names: | ||||
|   0: person | ||||
|   1: bicycle | ||||
|   2: car | ||||
|   3: motorcycle | ||||
|   4: airplane | ||||
|   5: bus | ||||
|   6: train | ||||
|   7: truck | ||||
|   8: boat | ||||
|   9: traffic light | ||||
|   10: fire hydrant | ||||
|   11: stop sign | ||||
|   12: parking meter | ||||
|   13: bench | ||||
|   14: bird | ||||
|   15: cat | ||||
|   16: dog | ||||
|   17: horse | ||||
|   18: sheep | ||||
|   19: cow | ||||
|   20: elephant | ||||
|   21: bear | ||||
|   22: zebra | ||||
|   23: giraffe | ||||
|   24: backpack | ||||
|   25: umbrella | ||||
|   26: handbag | ||||
|   27: tie | ||||
|   28: suitcase | ||||
|   29: frisbee | ||||
|   30: skis | ||||
|   31: snowboard | ||||
|   32: sports ball | ||||
|   33: kite | ||||
|   34: baseball bat | ||||
|   35: baseball glove | ||||
|   36: skateboard | ||||
|   37: surfboard | ||||
|   38: tennis racket | ||||
|   39: bottle | ||||
|   40: wine glass | ||||
|   41: cup | ||||
|   42: fork | ||||
|   43: knife | ||||
|   44: spoon | ||||
|   45: bowl | ||||
|   46: banana | ||||
|   47: apple | ||||
|   48: sandwich | ||||
|   49: orange | ||||
|   50: broccoli | ||||
|   51: carrot | ||||
|   52: hot dog | ||||
|   53: pizza | ||||
|   54: donut | ||||
|   55: cake | ||||
|   56: chair | ||||
|   57: couch | ||||
|   58: potted plant | ||||
|   59: bed | ||||
|   60: dining table | ||||
|   61: toilet | ||||
|   62: tv | ||||
|   63: laptop | ||||
|   64: mouse | ||||
|   65: remote | ||||
|   66: keyboard | ||||
|   67: cell phone | ||||
|   68: microwave | ||||
|   69: oven | ||||
|   70: toaster | ||||
|   71: sink | ||||
|   72: refrigerator | ||||
|   73: book | ||||
|   74: clock | ||||
|   75: vase | ||||
|   76: scissors | ||||
|   77: teddy bear | ||||
|   78: hair drier | ||||
|   79: toothbrush | ||||
|  | ||||
|  | ||||
| # Download script/URL (optional) | ||||
| download: https://ultralytics.com/assets/coco128.zip | ||||
| @ -1,11 +1,22 @@ | ||||
| import contextlib | ||||
| import hashlib | ||||
| import os | ||||
| import subprocess | ||||
| import time | ||||
| from pathlib import Path | ||||
| from tarfile import is_tarfile | ||||
| from zipfile import is_zipfile | ||||
|  | ||||
| import cv2 | ||||
| import numpy as np | ||||
| import torch | ||||
| from PIL import ExifTags, Image, ImageOps | ||||
|  | ||||
| from ultralytics.yolo.utils import LOGGER, ROOT, colorstr | ||||
| from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii | ||||
| from ultralytics.yolo.utils.downloads import download | ||||
| from ultralytics.yolo.utils.files import unzip_file, yaml_load | ||||
|  | ||||
| from ..utils.ops import segments2boxes | ||||
|  | ||||
| HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data" | ||||
| @ -176,3 +187,89 @@ def polygons2masks_overlap(img_size, segments, downsample_ratio=1): | ||||
|         masks = masks + mask | ||||
|         masks = np.clip(masks, a_min=0, a_max=i + 1) | ||||
|     return masks, index | ||||
|  | ||||
|  | ||||
| def check_dataset_yaml(data, autodownload=True): | ||||
|     # Download, check and/or unzip dataset if not found locally | ||||
|     data = check_file(data) | ||||
|     DATASETS_DIR = Path.cwd() / "../datasets" | ||||
|     # Download (optional) | ||||
|     extract_dir = '' | ||||
|     if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)): | ||||
|         download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1) | ||||
|         data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml')) | ||||
|         extract_dir, autodownload = data.parent, False | ||||
|     # Read yaml (optional) | ||||
|     if isinstance(data, (str, Path)): | ||||
|         data = yaml_load(data)  # dictionary | ||||
|  | ||||
|     # Checks | ||||
|     for k in 'train', 'val', 'names': | ||||
|         assert k in data, f"data.yaml '{k}:' field missing ❌" | ||||
|     if isinstance(data['names'], (list, tuple)):  # old array format | ||||
|         data['names'] = dict(enumerate(data['names']))  # convert to dict | ||||
|     data['nc'] = len(data['names']) | ||||
|  | ||||
|     # Resolve paths | ||||
|     path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.' | ||||
|     if not path.is_absolute(): | ||||
|         path = (Path.cwd() / path).resolve() | ||||
|         data['path'] = path  # download scripts | ||||
|     for k in 'train', 'val', 'test': | ||||
|         if data.get(k):  # prepend path | ||||
|             if isinstance(data[k], str): | ||||
|                 x = (path / data[k]).resolve() | ||||
|                 if not x.exists() and data[k].startswith('../'): | ||||
|                     x = (path / data[k][3:]).resolve() | ||||
|                 data[k] = str(x) | ||||
|             else: | ||||
|                 data[k] = [str((path / x).resolve()) for x in data[k]] | ||||
|  | ||||
|     # Parse yaml | ||||
|     train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download')) | ||||
|     if val: | ||||
|         val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path | ||||
|         if not all(x.exists() for x in val): | ||||
|             LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()]) | ||||
|             if not s or not autodownload: | ||||
|                 raise Exception('Dataset not found ❌') | ||||
|             t = time.time() | ||||
|             if s.startswith('http') and s.endswith('.zip'):  # URL | ||||
|                 f = Path(s).name  # filename | ||||
|                 LOGGER.info(f'Downloading {s} to {f}...') | ||||
|                 torch.hub.download_url_to_file(s, f) | ||||
|                 Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True)  # create root | ||||
|                 unzip_file(f, path=DATASETS_DIR)  # unzip | ||||
|                 Path(f).unlink()  # remove zip | ||||
|                 r = None  # success | ||||
|             elif s.startswith('bash '):  # bash script | ||||
|                 LOGGER.info(f'Running {s} ...') | ||||
|                 r = os.system(s) | ||||
|             else:  # python script | ||||
|                 r = exec(s, {'yaml': data})  # return None | ||||
|             dt = f'({round(time.time() - t, 1)}s)' | ||||
|             s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌" | ||||
|             LOGGER.info(f"Dataset download {s}") | ||||
|     check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True)  # download fonts | ||||
|     return data  # dictionary | ||||
|  | ||||
|  | ||||
| def check_dataset(dataset: str): | ||||
|     data = Path.cwd() / "datasets" / dataset | ||||
|     data_dir = data if data.is_dir() else (Path.cwd() / data) | ||||
|     if not data_dir.is_dir(): | ||||
|         LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...') | ||||
|         t = time.time() | ||||
|         if str(data) == 'imagenet': | ||||
|             subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) | ||||
|         else: | ||||
|             url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' | ||||
|             download(url, dir=data_dir.parent) | ||||
|         s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n" | ||||
|         LOGGER.info(s) | ||||
|     train_set = data_dir / "train" | ||||
|     test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val'  # data/test or data/val | ||||
|     nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()])  # number of classes | ||||
|     names = [name for name in os.listdir(data_dir / 'train') if os.path.isdir(data_dir / 'train' / name)] | ||||
|     data = {"train": train_set, "val": test_set, "nc": nc, "names": names} | ||||
|     return data | ||||
|  | ||||
| @ -24,7 +24,9 @@ from tqdm import tqdm | ||||
|  | ||||
| import ultralytics.yolo.utils as utils | ||||
| import ultralytics.yolo.utils.loggers as loggers | ||||
| from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml | ||||
| from ultralytics.yolo.utils import LOGGER, ROOT | ||||
| from ultralytics.yolo.utils.checks import check_file, check_yaml | ||||
| from ultralytics.yolo.utils.files import increment_path, save_yaml | ||||
| from ultralytics.yolo.utils.modeling import get_model | ||||
|  | ||||
| @ -55,9 +57,14 @@ class BaseTrainer: | ||||
|         self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu') | ||||
|  | ||||
|         # Model and Dataloaders. | ||||
|         self.trainset, self.testset = self.get_dataset(self.args.data) | ||||
|         self.data = self.args.data | ||||
|         if self.data.endswith(".yaml"): | ||||
|             self.data = check_dataset_yaml(self.data) | ||||
|         else: | ||||
|             self.data = check_dataset(self.data) | ||||
|         self.trainset, self.testset = self.get_dataset(self.data) | ||||
|         if self.args.cfg is not None: | ||||
|             self.model = self.load_cfg(self.args.cfg) | ||||
|             self.model = self.load_cfg(check_file(self.args.cfg)) | ||||
|         if self.args.model is not None: | ||||
|             self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device) | ||||
|  | ||||
| @ -250,10 +257,9 @@ class BaseTrainer: | ||||
|  | ||||
|     def get_dataset(self, data): | ||||
|         """ | ||||
|         Download the dataset if needed and verify it. | ||||
|         Returns train and val split datasets | ||||
|         Get train, val path from data dict if it exists. Returns None if data format is not recognized | ||||
|         """ | ||||
|         pass | ||||
|         return data["train"], data["val"] | ||||
|  | ||||
|     def get_model(self, model, pretrained): | ||||
|         """ | ||||
|  | ||||
| @ -7,7 +7,7 @@ from pathlib import Path | ||||
|  | ||||
| # Constants | ||||
| FILE = Path(__file__).resolve() | ||||
| ROOT = FILE.parents[2]  # YOLOv5 root directory | ||||
| ROOT = FILE.parents[2]  # YOLO | ||||
| RANK = int(os.getenv('RANK', -1)) | ||||
| DATASETS_DIR = ROOT.parent / 'datasets'  # YOLOv5 datasets directory | ||||
| NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads | ||||
|  | ||||
| @ -116,13 +116,10 @@ def check_file(file, suffix=''): | ||||
|             torch.hub.download_url_to_file(url, file) | ||||
|             assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check | ||||
|         return file | ||||
|     elif file.startswith('clearml://'):  # ClearML Dataset ID | ||||
|         assert 'clearml' in sys.modules, "Can not use ClearML dataset. Run 'pip install clearml' to install" | ||||
|         return file | ||||
|     else:  # search | ||||
|         files = [] | ||||
|         for d in 'data', 'models', 'utils':  # search directories | ||||
|             files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file | ||||
|         for d in 'data', 'v8', 'utils':  # search directories | ||||
|             files.extend(glob.glob(str(ROOT / "yolo" / d / '**' / file), recursive=True))  # find file | ||||
|         assert len(files), f'File not found: {file}'  # assert file was found | ||||
|         assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique | ||||
|         return files[0]  # return file | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| import contextlib | ||||
| import os | ||||
| from pathlib import Path | ||||
| from zipfile import ZipFile | ||||
|  | ||||
| import yaml | ||||
|  | ||||
| @ -44,3 +45,19 @@ def save_yaml(file='data.yaml', data=None): | ||||
|     # Single-line safe yaml saving | ||||
|     with open(file, 'w') as f: | ||||
|         yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False) | ||||
|  | ||||
|  | ||||
| def yaml_load(file='data.yaml'): | ||||
|     # Single-line safe yaml loading | ||||
|     with open(file, errors='ignore') as f: | ||||
|         return yaml.safe_load(f) | ||||
|  | ||||
|  | ||||
| def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')): | ||||
|     # Unzip a *.zip file to path/, excluding files containing strings in exclude list | ||||
|     if path is None: | ||||
|         path = Path(file).parent  # default path | ||||
|     with ZipFile(file) as zipObj: | ||||
|         for f in zipObj.namelist():  # list all archived filenames in the zip | ||||
|             if all(x not in f for x in exclude): | ||||
|                 zipObj.extract(f, path=path) | ||||
|  | ||||
| @ -118,9 +118,3 @@ def get_model(model='s.pt', pretrained=True): | ||||
|         return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None) | ||||
|     else:  # Ultralytics assets | ||||
|         return torch.load(attempt_download(f"{model}.pt"), map_location='cpu') | ||||
|  | ||||
|  | ||||
| def yaml_load(file='data.yaml'): | ||||
|     # Single-line safe yaml loading | ||||
|     with open(file, errors='ignore') as f: | ||||
|         return yaml.safe_load(f) | ||||
|  | ||||
| @ -32,7 +32,8 @@ class AutoBackend(nn.Module): | ||||
|         #   TensorFlow Lite:                *.tflite | ||||
|         #   TensorFlow Edge TPU:            *_edgetpu.tflite | ||||
|         #   PaddlePaddle:                   *_paddle_model | ||||
|         from ultralytics.yolo.utils.modeling import attempt_load_weights, yaml_load | ||||
|         from ultralytics.yolo.utils.files import yaml_load | ||||
|         from ultralytics.yolo.utils.modeling import attempt_load_weights | ||||
|  | ||||
|         super().__init__() | ||||
|         w = str(weights[0] if isinstance(weights, list) else weights) | ||||
| @ -315,7 +316,7 @@ class AutoBackend(nn.Module): | ||||
|  | ||||
|     @staticmethod | ||||
|     def _load_metadata(f=Path('path/to/meta.yaml')): | ||||
|         from ultralytics.yolo.utils.modeling import yaml_load | ||||
|         from ultralytics.yolo.utils.files import yaml_load | ||||
|  | ||||
|         # Load metadata from meta.yaml if it exists | ||||
|         if f.exists(): | ||||
|  | ||||
| @ -17,26 +17,6 @@ from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zer | ||||
| # BaseTrainer python usage | ||||
| class ClassificationTrainer(BaseTrainer): | ||||
|  | ||||
|     def get_dataset(self, dataset): | ||||
|         # temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module | ||||
|         data = Path("datasets") / dataset | ||||
|         with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()): | ||||
|             data_dir = data if data.is_dir() else (Path.cwd() / data) | ||||
|             if not data_dir.is_dir(): | ||||
|                 self.console.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...') | ||||
|                 t = time.time() | ||||
|                 if str(data) == 'imagenet': | ||||
|                     subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) | ||||
|                 else: | ||||
|                     url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' | ||||
|                     download(url, dir=data_dir.parent) | ||||
|                 s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n" | ||||
|                 self.console.info(s) | ||||
|         train_set = data_dir / "train" | ||||
|         test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val'  # data/test or data/val | ||||
|  | ||||
|         return train_set, test_set | ||||
|  | ||||
|     def get_dataloader(self, dataset_path, batch_size=None, rank=0): | ||||
|         return build_classification_dataloader(path=dataset_path, | ||||
|                                                imgsz=self.args.img_size, | ||||
|  | ||||
| @ -21,26 +21,6 @@ from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, de_parallel, torch_di | ||||
| # BaseTrainer python usage | ||||
| class SegmentationTrainer(BaseTrainer): | ||||
|  | ||||
|     def get_dataset(self, dataset): | ||||
|         # temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module | ||||
|         data = Path("datasets") / dataset | ||||
|         with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()): | ||||
|             data_dir = data if data.is_dir() else (Path.cwd() / data) | ||||
|             if not data_dir.is_dir(): | ||||
|                 self.console.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...') | ||||
|                 t = time.time() | ||||
|                 if str(data) == 'imagenet': | ||||
|                     subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) | ||||
|                 else: | ||||
|                     url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' | ||||
|                     download(url, dir=data_dir.parent) | ||||
|                 # TODO: add colorstr | ||||
|                 s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n" | ||||
|                 self.console.info(s) | ||||
|         train_set = data_dir.parent / "coco128-seg" | ||||
|         test_set = train_set | ||||
|         return train_set, test_set | ||||
|  | ||||
|     def get_dataloader(self, dataset_path, batch_size, rank=0): | ||||
|         # TODO: manage splits differently | ||||
|         # calculate stride - check if model is initialized | ||||
| @ -253,7 +233,7 @@ class SegmentationTrainer(BaseTrainer): | ||||
| @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name) | ||||
| def train(cfg): | ||||
|     cfg.cfg = v8.ROOT / "models/yolov5n-seg.yaml" | ||||
|     cfg.data = cfg.data or "coco128-segments"  # or yolo.ClassificationDataset("mnist") | ||||
|     cfg.data = cfg.data or "coco128-seg.yaml"  # or yolo.ClassificationDataset("mnist") | ||||
|     trainer = SegmentationTrainer(cfg) | ||||
|     trainer.train() | ||||
|  | ||||
|  | ||||
| @ -8,9 +8,9 @@ import torch.nn.functional as F | ||||
| from ultralytics.yolo.engine.validator import BaseValidator | ||||
| from ultralytics.yolo.utils import ops | ||||
| from ultralytics.yolo.utils.checks import check_requirements | ||||
| from ultralytics.yolo.utils.files import yaml_load | ||||
| from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou, | ||||
|                                             fitness_segmentation, mask_iou) | ||||
| from ultralytics.yolo.utils.modeling import yaml_load | ||||
| from ultralytics.yolo.utils.torch_utils import de_parallel | ||||
|  | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user