Check PyTorch model status for all YOLO
methods (#945)
Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -5,12 +5,12 @@ import requests
|
||||
from ultralytics.hub.auth import Auth
|
||||
from ultralytics.hub.session import HubTrainingSession
|
||||
from ultralytics.hub.utils import split_key
|
||||
from ultralytics.yolo.engine.exporter import export_formats
|
||||
from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_LIST
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.utils import LOGGER, PREFIX, emojis
|
||||
|
||||
# Define all export formats
|
||||
EXPORT_FORMATS = list(export_formats()['Argument'][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
|
||||
EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ["ultralytics_tflite", "ultralytics_coreml"]
|
||||
|
||||
|
||||
def start(key=""):
|
||||
@ -69,7 +69,7 @@ def reset_model(key=""):
|
||||
|
||||
def export_model(key="", format="torchscript"):
|
||||
# Export a model to all formats
|
||||
assert format in EXPORT_FORMATS, f"Unsupported export format '{format}' passed, valid formats are {EXPORT_FORMATS}"
|
||||
assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
|
||||
api_key, model_id = split_key(key)
|
||||
r = requests.post("https://api.ultralytics.com/export",
|
||||
json={
|
||||
@ -82,7 +82,7 @@ def export_model(key="", format="torchscript"):
|
||||
|
||||
def get_export(key="", format="torchscript"):
|
||||
# Get an exported model dictionary with download URL
|
||||
assert format in EXPORT_FORMATS, f"Unsupported export format '{format}' passed, valid formats are {EXPORT_FORMATS}"
|
||||
assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
|
||||
api_key, model_id = split_key(key)
|
||||
r = requests.post("https://api.ultralytics.com/get-export",
|
||||
json={
|
||||
|
Reference in New Issue
Block a user