New guess_model_task()
function (#614)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
"""
|
||||
Auto-batch utils
|
||||
AutoBatch utils
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
@ -308,23 +308,6 @@ def strip_optimizer(f='best.pt', s=''):
|
||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||
|
||||
|
||||
def guess_task_from_model_yaml(model):
|
||||
try:
|
||||
cfg = model if isinstance(model, dict) else model.yaml # model cfg dict
|
||||
m = cfg["head"][-1][-2].lower() # output module name
|
||||
task = None
|
||||
if m in ["classify", "classifier", "cls", "fc"]:
|
||||
task = "classify"
|
||||
if m in ["detect"]:
|
||||
task = "detect"
|
||||
if m in ["segment"]:
|
||||
task = "segment"
|
||||
except Exception as e:
|
||||
raise SyntaxError('Unknown task. Define task explicitly, i.e. task=detect when running your command. '
|
||||
'Valid tasks are detect, segment, classify.') from e
|
||||
return task
|
||||
|
||||
|
||||
def profile(input, ops, n=10, device=None):
|
||||
""" YOLOv8 speed/memory/FLOPs profiler
|
||||
Usage:
|
||||
|
Reference in New Issue
Block a user