Model interface enhancement (#106)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
		
							
								
								
									
										6
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							| @ -91,14 +91,14 @@ jobs: | ||||
|         shell: bash  # for Windows compatibility | ||||
|         run: | | ||||
|           yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=1 imgsz=64 | ||||
|           yolo task=detect mode=val model=runs/train/exp/weights/last.pt imgsz=64 | ||||
|           yolo task=detect mode=val model=runs/detect/train/weights/last.pt imgsz=64 | ||||
|       - name: Test segmentation  # TODO: segmentation CI | ||||
|         shell: bash  # for Windows compatibility | ||||
|         run: | | ||||
|         # yolo task=segment mode=train model=yolov8n-seg.yaml data=coco128-seg.yaml epochs=1 imgsz=64 | ||||
|         # yolo task=segment mode=val model=runs/train/exp2/weights/last.pt data=coco128-seg.yaml imgsz=64 | ||||
|         # yolo task=segment mode=val model=runs/segment/train/weights/last.pt data=coco128-seg.yaml imgsz=64 | ||||
|       - name: Test classification  # TODO: change to exp3 on Segmentation CI update | ||||
|         shell: bash  # for Windows compatibility | ||||
|         run: | | ||||
|           yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 imgsz=32 | ||||
|           yolo task=classify mode=val model=runs/train/exp2/weights/last.pt data=mnist160 | ||||
|           yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160 | ||||
|  | ||||
| @ -1,11 +1,11 @@ | ||||
| import torch | ||||
|  | ||||
| from ultralytics.yolo import YOLO | ||||
| from ultralytics import YOLO | ||||
|  | ||||
|  | ||||
| def test_model_forward(): | ||||
|     model = YOLO() | ||||
|     model.new("yolov8n-seg.yaml") | ||||
|     model.new("yolov8n.yaml") | ||||
|     img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512) | ||||
|     model.forward(img) | ||||
|     model(img) | ||||
| @ -15,7 +15,7 @@ def test_model_info(): | ||||
|     model = YOLO() | ||||
|     model.new("yolov8n.yaml") | ||||
|     model.info() | ||||
|     model.load("balloon-detect.pt") | ||||
|     model.load("best.pt") | ||||
|     model.info(verbose=True) | ||||
|  | ||||
|  | ||||
| @ -23,35 +23,35 @@ def test_model_fuse(): | ||||
|     model = YOLO() | ||||
|     model.new("yolov8n.yaml") | ||||
|     model.fuse() | ||||
|     model.load("balloon-detect.pt") | ||||
|     model.load("best.pt") | ||||
|     model.fuse() | ||||
|  | ||||
|  | ||||
| def test_visualize_preds(): | ||||
|     model = YOLO() | ||||
|     model.load("balloon-segment.pt") | ||||
|     model.load("best.pt") | ||||
|     model.predict(source="ultralytics/assets") | ||||
|  | ||||
|  | ||||
| def test_val(): | ||||
|     model = YOLO() | ||||
|     model.load("balloon-segment.pt") | ||||
|     model.val(data="coco128-seg.yaml", imgsz=32) | ||||
|     model.load("best.pt") | ||||
|     model.val(data="coco128.yaml", imgsz=32) | ||||
|  | ||||
|  | ||||
| def test_model_resume(): | ||||
|     model = YOLO() | ||||
|     model.new("yolov8n-seg.yaml") | ||||
|     model.train(epochs=1, imgsz=32, data="coco128-seg.yaml") | ||||
|     model.new("yolov8n.yaml") | ||||
|     model.train(epochs=1, imgsz=32, data="coco128.yaml") | ||||
|     try: | ||||
|         model.resume(task="segment") | ||||
|         model.resume(task="detect") | ||||
|     except AssertionError: | ||||
|         print("Successfully caught resume assert!") | ||||
|  | ||||
|  | ||||
| def test_model_train_pretrained(): | ||||
|     model = YOLO() | ||||
|     model.load("balloon-detect.pt") | ||||
|     model.load("best.pt") | ||||
|     model.train(data="coco128.yaml", epochs=1, imgsz=32) | ||||
|     model.new("yolov8n.yaml") | ||||
|     model.train(data="coco128.yaml", epochs=1, imgsz=32) | ||||
|  | ||||
| @ -43,6 +43,7 @@ class YOLO: | ||||
|         self.trainer = None | ||||
|         self.task = None | ||||
|         self.ckpt = None | ||||
|         self.overrides = {} | ||||
|  | ||||
|     def new(self, cfg: str): | ||||
|         """ | ||||
| @ -69,6 +70,10 @@ class YOLO: | ||||
|         """ | ||||
|         self.ckpt = torch.load(weights, map_location="cpu") | ||||
|         self.task = self.ckpt["train_args"]["task"] | ||||
|         self.overrides = dict(self.ckpt["train_args"]) | ||||
|         self.overrides["device"] = ''  # reset device | ||||
|         LOGGER.info("Device has been reset to ''") | ||||
|  | ||||
|         self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task( | ||||
|             task=self.task) | ||||
|         self.model = attempt_load_weights(weights) | ||||
| @ -107,6 +112,7 @@ class YOLO: | ||||
|         source (str): Accepts all source types accepted by yolo | ||||
|         **kwargs : Any other args accepted by the predictors. Too see all args check 'configuration' section in the docs | ||||
|         """ | ||||
|         kwargs.update(self.overrides) | ||||
|         predictor = self.PredictorClass(overrides=kwargs) | ||||
|  | ||||
|         # check size type | ||||
| @ -119,7 +125,7 @@ class YOLO: | ||||
|         predictor.setup(model=self.model, source=source) | ||||
|         predictor() | ||||
|  | ||||
|     def val(self, data, **kwargs): | ||||
|     def val(self, data=None, **kwargs): | ||||
|         """ | ||||
|         Validate a model on a given dataset | ||||
|  | ||||
| @ -130,8 +136,9 @@ class YOLO: | ||||
|         if not self.model: | ||||
|             raise Exception("model not initialized!") | ||||
|  | ||||
|         kwargs.update(self.overrides) | ||||
|         args = get_config(config=DEFAULT_CONFIG, overrides=kwargs) | ||||
|         args.data = data | ||||
|         args.data = data or args.data | ||||
|         args.task = self.task | ||||
|  | ||||
|         validator = self.ValidatorClass(args=args) | ||||
|  | ||||
| @ -86,10 +86,15 @@ class BasePredictor: | ||||
|  | ||||
|         # data | ||||
|         if self.data: | ||||
|             if self.data.endswith(".yaml"): | ||||
|                 self.data = check_dataset_yaml(self.data) | ||||
|             else: | ||||
|                 self.data = check_dataset(self.data) | ||||
|             try: | ||||
|                 if self.data.endswith(".yaml"): | ||||
|                     self.data = check_dataset_yaml(self.data) | ||||
|                 else: | ||||
|                     self.data = check_dataset(self.data) | ||||
|             except AssertionError as e: | ||||
|                 LOGGER.info(f"Error ocurred: {e}") | ||||
|             finally: | ||||
|                 LOGGER.info("Predictor will continue without reading the dataset") | ||||
|  | ||||
|         # model | ||||
|         device = select_device(self.args.device) | ||||
|  | ||||
| @ -46,10 +46,15 @@ class BaseTrainer: | ||||
|         self.validator = None | ||||
|         self.model = None | ||||
|         self.callbacks = defaultdict(list) | ||||
|         self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok) | ||||
|  | ||||
|         # dirs | ||||
|         project = overrides.get("project") or self.args.task | ||||
|         name = overrides.get("name") or self.args.mode | ||||
|         self.save_dir = increment_path(Path("runs") / project / name, exist_ok=self.args.exist_ok) | ||||
|         self.wdir = self.save_dir / 'weights'  # weights dir | ||||
|         self.wdir.mkdir(parents=True, exist_ok=True)  # make dir | ||||
|         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths | ||||
|  | ||||
|         self.batch_size = self.args.batch_size | ||||
|         self.epochs = self.args.epochs | ||||
|         self.start_epoch = 0 | ||||
|  | ||||
| @ -6,7 +6,7 @@ from omegaconf import DictConfig, OmegaConf | ||||
| from ultralytics.yolo.utils.configs.hydra_patch import check_config_mismatch | ||||
|  | ||||
|  | ||||
| def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict]): | ||||
| def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}): | ||||
|     """ | ||||
|     Accepts yaml file name or DictConfig containing experiment configuration. | ||||
|     Returns training args namespace | ||||
|  | ||||
		Reference in New Issue
	
	Block a user