diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py index dc457a0..ababe0c 100644 --- a/ultralytics/yolo/engine/exporter.py +++ b/ultralytics/yolo/engine/exporter.py @@ -313,14 +313,11 @@ class Exporter: # Simplify if self.args.simplify: try: - cuda = torch.cuda.is_available() - check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1')) - import onnxsim # noqa + check_requirements('onnxsim') + import onnxsim LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...') - model_onnx, check = onnxsim.simplify(model_onnx) - assert check, 'assert check failed' - onnx.save(model_onnx, f) + subprocess.run(f'onnxsim {f} {f}', shell=True) except Exception as e: LOGGER.info(f'{prefix} simplifier failure: {e}') return f, model_onnx @@ -460,6 +457,40 @@ class Exporter: iou_thres=0.45, conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')): + + # YOLOv5 TensorFlow SavedModel export + try: + import tensorflow as tf # noqa + except ImportError: + check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}") + import tensorflow as tf # noqa + check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon"), + cmds="--extra-index-url https://pypi.ngc.nvidia.com ") + + LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') + f = str(self.file).replace(self.file.suffix, '_saved_model') + + # Export to ONNX + self._export_onnx() + onnx = self.file.with_suffix('.onnx') + + # Export to TF SavedModel + subprocess.run(f'onnx2tf -i {onnx} --output_signaturedefs -o {f}', shell=True) + + # Load saved_model + keras_model = tf.saved_model.load(f, tags=None, options=None) + + return f, keras_model + + @try_export + def _export_saved_model_OLD(self, + nms=False, + agnostic_nms=False, + topk_per_class=100, + topk_all=100, + iou_thres=0.45, + conf_thres=0.25, + prefix=colorstr('TensorFlow SavedModel:')): # YOLOv5 TensorFlow SavedModel export try: import tensorflow as tf # noqa diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 6f148ac..7c0bd30 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -52,8 +52,8 @@ class YOLO: # Load or create new YOLO model {'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model) - def __call__(self, source): - return self.predict(source) + def __call__(self, source, **kwargs): + return self.predict(source, **kwargs) def _new(self, cfg: str, verbose=True): """ @@ -218,3 +218,4 @@ class YOLO: args.pop("name", None) args.pop("batch", None) args.pop("epochs", None) + args.pop("cache", None)