Create Exporter() Class (#117)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@ -5,7 +5,7 @@ import torch
|
||||
from ultralytics import yolo # noqa required for python usage
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.engine.exporter import export_model
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
from ultralytics.yolo.utils.files import yaml_load
|
||||
@ -164,7 +164,7 @@ class YOLO:
|
||||
validator(model=self.model)
|
||||
|
||||
@smart_inference_mode()
|
||||
def export(self, format='', save_dir='', **kwargs):
|
||||
def export(self, **kwargs):
|
||||
"""
|
||||
Export model.
|
||||
|
||||
@ -177,36 +177,9 @@ class YOLO:
|
||||
overrides.update(kwargs)
|
||||
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
|
||||
args.task = self.task
|
||||
args.format = format
|
||||
|
||||
file = self.ckpt or Path(Path(self.cfg).name)
|
||||
if save_dir:
|
||||
file = Path(save_dir) / file.name
|
||||
file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
export_model(
|
||||
model=self.model,
|
||||
file=file,
|
||||
data=args.data, # 'dataset.yaml path'
|
||||
imgsz=args.imgsz or (640, 640), # image (height, width)
|
||||
batch_size=1, # batch size
|
||||
device=args.device, # cuda device, i.e. 0 or 0,1,2,3 or cpu
|
||||
format=args.format, # include formats
|
||||
half=args.half or False, # FP16 half-precision export
|
||||
keras=args.keras or False, # use Keras
|
||||
optimize=args.optimize or False, # TorchScript: optimize for mobile
|
||||
int8=args.int8 or False, # CoreML/TF INT8 quantization
|
||||
dynamic=args.dynamic or False, # ONNX/TF/TensorRT: dynamic axes
|
||||
opset=args.opset or 17, # ONNX: opset version
|
||||
verbose=False, # TensorRT: verbose log
|
||||
workspace=args.workspace or 4, # TensorRT: workspace size (GB)
|
||||
nms=False, # TF: add NMS to model
|
||||
agnostic_nms=False, # TF: add agnostic NMS to model
|
||||
topk_per_class=100, # TF.js NMS: topk per class to keep
|
||||
topk_all=100, # TF.js NMS: topk for all classes to keep
|
||||
iou_thres=0.45, # TF.js NMS: IoU threshold
|
||||
conf_thres=0.25, # TF.js NMS: confidence threshold
|
||||
)
|
||||
exporter = Exporter(overrides=overrides)
|
||||
exporter(model=self.model)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
|
@ -16,14 +16,14 @@ Usage - formats:
|
||||
$ yolo task=... mode=predict --weights yolov8n.pt # PyTorch
|
||||
yolov8n.torchscript # TorchScript
|
||||
yolov8n.onnx # ONNX Runtime or OpenCV DNN with --dnn
|
||||
yolov5s_openvino_model # OpenVINO
|
||||
yolov8n_openvino_model # OpenVINO
|
||||
yolov8n.engine # TensorRT
|
||||
yolov8n.mlmodel # CoreML (macOS-only)
|
||||
yolov5s_saved_model # TensorFlow SavedModel
|
||||
yolov8n_saved_model # TensorFlow SavedModel
|
||||
yolov8n.pb # TensorFlow GraphDef
|
||||
yolov8n.tflite # TensorFlow Lite
|
||||
yolov5s_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolov5s_paddle_model # PaddlePaddle
|
||||
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolov8n_paddle_model # PaddlePaddle
|
||||
"""
|
||||
import platform
|
||||
from pathlib import Path
|
||||
|
Reference in New Issue
Block a user