New guess_model_task()
function (#614)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -251,7 +251,7 @@ class ClassificationModel(BaseModel):
|
||||
ch=3,
|
||||
nc=1000,
|
||||
cutoff=10,
|
||||
verbose=True): # yaml, model, number of classes, cutoff index
|
||||
verbose=True): # yaml, model, channels, number of classes, cutoff index, verbose flag
|
||||
super().__init__()
|
||||
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
|
||||
|
||||
@ -457,3 +457,53 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||
ch = []
|
||||
ch.append(c2)
|
||||
return nn.Sequential(*layers), sorted(save)
|
||||
|
||||
|
||||
def guess_model_task(model):
|
||||
"""
|
||||
Guess the task of a PyTorch model from its architecture or configuration.
|
||||
|
||||
Args:
|
||||
model (nn.Module) or (dict): PyTorch model or model configuration in YAML format.
|
||||
|
||||
Returns:
|
||||
str: Task of the model ('detect', 'segment', 'classify').
|
||||
|
||||
Raises:
|
||||
SyntaxError: If the task of the model could not be determined.
|
||||
"""
|
||||
cfg, task = None, None
|
||||
if isinstance(model, dict):
|
||||
cfg = model
|
||||
elif isinstance(model, nn.Module): # PyTorch model
|
||||
for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
|
||||
with contextlib.suppress(Exception):
|
||||
cfg = eval(x)
|
||||
break
|
||||
|
||||
# Guess from YAML dictionary
|
||||
if cfg:
|
||||
m = cfg["head"][-1][-2].lower() # output module name
|
||||
if m in ["classify", "classifier", "cls", "fc"]:
|
||||
task = "classify"
|
||||
if m in ["detect"]:
|
||||
task = "detect"
|
||||
if m in ["segment"]:
|
||||
task = "segment"
|
||||
|
||||
# Guess from PyTorch model
|
||||
if task is None and isinstance(model, nn.Module):
|
||||
for m in model.modules():
|
||||
if isinstance(m, Detect):
|
||||
task = "detect"
|
||||
elif isinstance(m, Segment):
|
||||
task = "segment"
|
||||
elif isinstance(m, Classify):
|
||||
task = "classify"
|
||||
|
||||
# Unable to determine task from model
|
||||
if task is None:
|
||||
raise SyntaxError("YOLO is unable to automatically guess model task. Explicitly define task for your model, "
|
||||
"i.e. 'task=detect', 'task=segment' or 'task=classify'.")
|
||||
else:
|
||||
return task
|
||||
|
Reference in New Issue
Block a user