New guess_model_task() function (#614)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-25 02:24:28 +01:00
committed by GitHub
parent 520825c4b2
commit 59d4335664
6 changed files with 63 additions and 29 deletions

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
"""
Auto-batch utils
AutoBatch utils
"""
from copy import deepcopy

View File

@ -308,23 +308,6 @@ def strip_optimizer(f='best.pt', s=''):
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
def guess_task_from_model_yaml(model):
try:
cfg = model if isinstance(model, dict) else model.yaml # model cfg dict
m = cfg["head"][-1][-2].lower() # output module name
task = None
if m in ["classify", "classifier", "cls", "fc"]:
task = "classify"
if m in ["detect"]:
task = "detect"
if m in ["segment"]:
task = "segment"
except Exception as e:
raise SyntaxError('Unknown task. Define task explicitly, i.e. task=detect when running your command. '
'Valid tasks are detect, segment, classify.') from e
return task
def profile(input, ops, n=10, device=None):
""" YOLOv8 speed/memory/FLOPs profiler
Usage: