From eb5adf4e0b7a4cb71af43b994db236a8c90830ed Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 14 Dec 2022 14:33:31 +0530 Subject: [PATCH] Model enhancement (#75) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/tests/test_model.py | 61 ++++++++- ultralytics/yolo/engine/model.py | 126 +++++++++++++++--- .../yolo/utils/modeling/autobackend.py | 20 ++- 3 files changed, 176 insertions(+), 31 deletions(-) diff --git a/ultralytics/tests/test_model.py b/ultralytics/tests/test_model.py index 353fab1..bd1b6ee 100644 --- a/ultralytics/tests/test_model.py +++ b/ultralytics/tests/test_model.py @@ -1,13 +1,62 @@ +import torch + from ultralytics.yolo import YOLO -def test_model(): +def test_model_forward(): + model = YOLO() + model.new("yolov5n-seg.yaml") + img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512) + model.forward(img) + model(img) + + +def test_model_info(): + model = YOLO() + model.new("yolov5n.yaml") + model.info() + model.load("balloon-detect.pt") + model.info(verbose=True) + + +def test_model_fuse(): + model = YOLO() + model.new("yolov5n.yaml") + model.fuse() + model.load("balloon-detect.pt") + model.fuse() + + +def test_visualize_preds(): model = YOLO() - model.new("assets/dummy_model.yaml") - model.model = "squeezenet1_0" # temp solution before get_model is implemented - # model.load("yolov5n.pt") - model.train(data="imagenette160", epochs=1, lr0=0.01) + model.load("balloon-segment.pt") + model.predict(source="ultralytics/assets") + + +def test_val(): + model = YOLO() + model.load("balloon-segment.pt") + model.val(data="coco128-seg.yaml", img_size=32) + + +def test_model_resume(): + model = YOLO() + model.new("yolov5n-seg.yaml") + model.train(epochs=1, img_size=32, data="coco128-seg.yaml") + try: + model.resume(task="segment") + except AssertionError: + print("Successfully caught resume assert!") + + +def test(): + test_model_forward() + test_model_info() + test_model_fuse() + test_visualize_preds() + test_val() + test_model_resume() if __name__ == "__main__": - test_model() + test() diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 733a438..b499846 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -1,18 +1,28 @@ import torch import yaml +from omegaconf import OmegaConf from ultralytics import yolo +from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG from ultralytics.yolo.utils import LOGGER from ultralytics.yolo.utils.checks import check_yaml +from ultralytics.yolo.utils.configs import get_config from ultralytics.yolo.utils.files import yaml_load from ultralytics.yolo.utils.modeling import attempt_load_weights from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel +from ultralytics.yolo.utils.torch_utils import smart_inference_mode -# map head: [model, trainer] +# map head: [model, trainer, validator, predictor] MODEL_MAP = { - "classify": [ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer'], - "detect": [DetectionModel, 'yolo.TYPE.detect.DetectionTrainer'], - "segment": [SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer']} + "classify": [ + ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator', + 'yolo.TYPE.classify.ClassificationPredictor'], + "detect": [ + DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator', + 'yolo.TYPE.detect.DetectionPredictor'], + "segment": [ + SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator', + 'yolo.TYPE.segment.SegmentationPredictor']} class YOLO: @@ -28,6 +38,8 @@ class YOLO: self.type = type self.ModelClass = None self.TrainerClass = None + self.ValidatorClass = None + self.PredictorClass = None self.model = None self.trainer = None self.task = None @@ -43,7 +55,9 @@ class YOLO: cfg = check_yaml(cfg) # check YAML with open(cfg, encoding='ascii', errors='ignore') as f: cfg = yaml.safe_load(f) # model dict - self.ModelClass, self.TrainerClass, self.task = self._guess_model_trainer_and_task(cfg["head"][-1][-2]) + self.task = self._guess_task_from_head(cfg["head"][-1][-2]) + self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task( + self.task) self.model = self.ModelClass(cfg) # initialize def load(self, weights: str): @@ -56,8 +70,8 @@ class YOLO: """ self.ckpt = torch.load(weights, map_location="cpu") self.task = self.ckpt["train_args"]["task"] - _, trainer_class_literal = MODEL_MAP[self.task] - self.TrainerClass = eval(trainer_class_literal.replace("TYPE", f"v{self.type}")) + self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task( + task=self.task) self.model = attempt_load_weights(weights) def reset(self): @@ -70,6 +84,60 @@ class YOLO: for p in self.model.parameters(): p.requires_grad = True + def info(self, verbose=False): + """ + Logs model info + + Args: + verbose (bool): Controls verbosity. + """ + if not self.model: + LOGGER.info("model not initialized!") + self.model.info(verbose=verbose) + + def fuse(self): + if not self.model: + LOGGER.info("model not initialized!") + self.model.fuse() + + def predict(self, source, **kwargs): + """ + Visualize prection. + + Args: + source (str): Accepts all source types accepted by yolo + **kwargs : Any other args accepted by the predictors. Too see all args check 'configuration' section in the docs + """ + predictor = self.PredictorClass(overrides=kwargs) + + # check size type + sz = predictor.args.img_size + if type(sz) != int: # recieved listConfig + predictor.args.img_size = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]] # expand + else: + predictor.args.img_size = [sz, sz] + + predictor.setup(model=self.model, source=source) + predictor() + + def val(self, data, **kwargs): + """ + Validate a model on a given dataset + + Args: + data (str): The dataset to validate on. Accepts all formats accepted by yolo + kwargs: Any other args accepted by the validators. Too see all args check 'configuration' section in the docs + """ + if not self.model: + raise Exception("model not initialized!") + + args = get_config(config=DEFAULT_CONFIG, overrides=kwargs) + args.data = data + args.task = self.task + + validator = self.ValidatorClass(args=args) + validator(model=self.model) + def train(self, **kwargs): """ Trains the model on given dataset. @@ -95,22 +163,28 @@ class YOLO: self.trainer.model = self.trainer.load_model(weights=self.ckpt) if self.ckpt else self.model self.trainer.train() - def resume(self, task, model=None): + def resume(self, task=None, model=None): """ - Resume a training task. - + Resume a training task. Requires either `task` or `model`. `model` takes the higher precederence. Args: task (str): The task type you want to resume. Automatically finds the last run to resume if `model` is not specified. - model (str): [Optional] The model checkpoint to resume from. If not found, the last run of the given task type is resumed. + model (str): The model checkpoint to resume from. If not found, the last run of the given task type is resumed. + If `model` is speficied """ - if task.lower() not in MODEL_MAP: - raise Exception(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}") - _, trainer_class_literal = MODEL_MAP[task.lower()] - self.TrainerClass = eval(trainer_class_literal.replace("TYPE", f"v{self.type}")) + if task: + if task.lower() not in MODEL_MAP: + raise Exception(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}") + else: + ckpt = torch.load(model, map_location="cpu") + task = ckpt["train_args"]["task"] + del ckpt + self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task( + task=task.lower()) self.trainer = self.TrainerClass(overrides={"task": task.lower(), "resume": model if model else True}) self.trainer.train() - def _guess_model_trainer_and_task(self, head): + @staticmethod + def _guess_task_from_head(head): task = None if head.lower() in ["classify", "classifier", "cls", "fc"]: task = "classify" @@ -118,13 +192,27 @@ class YOLO: task = "detect" if head.lower() in ["segment"]: task = "segment" - model_class, trainer_class = MODEL_MAP[task] + + if not task: + raise Exception( + "task or model not recognized! Please refer the docs at : ") # TODO: add gitHub and docs links + + return task + + def _guess_ops_from_task(self, task): + model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task] # warning: eval is unsafe. Use with caution - trainer_class = eval(trainer_class.replace("TYPE", f"{self.type}")) + trainer_class = eval(train_lit.replace("TYPE", f"{self.type}")) + validator_class = eval(val_lit.replace("TYPE", f"{self.type}")) + predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}")) - return model_class, trainer_class, task + return model_class, trainer_class, validator_class, predictor_class + @smart_inference_mode() def __call__(self, imgs): if not self.model: LOGGER.info("model not initialized!") return self.model(imgs) + + def forward(self, imgs): + return self.__call__(imgs) diff --git a/ultralytics/yolo/utils/modeling/autobackend.py b/ultralytics/yolo/utils/modeling/autobackend.py index 23da107..d5366a5 100644 --- a/ultralytics/yolo/utils/modeling/autobackend.py +++ b/ultralytics/yolo/utils/modeling/autobackend.py @@ -37,15 +37,23 @@ class AutoBackend(nn.Module): super().__init__() w = str(weights[0] if isinstance(weights, list) else weights) + nn_module = isinstance(weights, torch.nn.Module) pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w) - fp16 &= pt or jit or onnx or engine # FP16 + fp16 &= pt or jit or onnx or engine or nn_module # FP16 nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH) stride = 32 # default stride cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA - if not (pt or triton): + if not (pt or triton or nn_module): w = attempt_download(w) # download if not local - if pt: # PyTorch + # NOTE: special case: in-memory pytorch model + if nn_module: + model = weights.to(device) + model = model.fuse() if fuse else model + names = model.module.names if hasattr(model, 'module') else model.names # get class names + model.half() if fp16 else model.float() + self.model = model # explicitly assign for to(), cpu(), cuda(), half() + elif pt: # PyTorch model = attempt_load_weights(weights if isinstance(weights, list) else w, device=device, inplace=True, @@ -215,7 +223,7 @@ class AutoBackend(nn.Module): if self.nhwc: im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3) - if self.pt: # PyTorch + if self.pt or self.nn_module: # PyTorch y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im) elif self.jit: # TorchScript y = self.model(im) @@ -294,7 +302,7 @@ class AutoBackend(nn.Module): def warmup(self, imgsz=(1, 3, 640, 640)): # Warmup model by running inference once - warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton + warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module if any(warmup_types) and (self.device.type != 'cpu' or self.triton): im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input for _ in range(2 if self.jit else 1): # @@ -306,7 +314,7 @@ class AutoBackend(nn.Module): # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle] from ultralytics.yolo.engine.exporter import export_formats sf = list(export_formats().Suffix) # export suffixes - if not is_url(p, check=False): + if not is_url(p, check=False) and not isinstance(p, str): check_suffix(p, sf) # checks url = urlparse(p) # if url may be Triton inference server types = [s in Path(p).name for s in sf]