ultralytics 8.0.44
export and task fixes (#1088)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
@ -181,7 +181,6 @@ class AutoBackend(nn.Module):
|
||||
import tensorflow as tf
|
||||
keras = False # assume TF1 saved_model
|
||||
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
|
||||
w = Path(w) / 'metadata.yaml'
|
||||
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
|
||||
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
|
||||
import tensorflow as tf
|
||||
@ -258,8 +257,9 @@ class AutoBackend(nn.Module):
|
||||
f'\n\n{EXPORT_FORMATS_TABLE}')
|
||||
|
||||
# Load external metadata YAML
|
||||
w = Path(w)
|
||||
if xml or saved_model or paddle:
|
||||
metadata = Path(w).parent / 'metadata.yaml'
|
||||
metadata = (w if saved_model else w.parents[1] if paddle else w.parent) / 'metadata.yaml'
|
||||
if metadata.exists():
|
||||
metadata = yaml_load(metadata)
|
||||
stride, names = int(metadata['stride']), metadata['names'] # load metadata
|
||||
|
@ -287,6 +287,7 @@ class ClassificationModel(BaseModel):
|
||||
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
||||
self.yaml['nc'] = nc # override yaml value
|
||||
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
||||
self.stride = torch.Tensor([1]) # no stride constraints
|
||||
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
|
||||
self.info()
|
||||
|
||||
@ -520,14 +521,15 @@ def guess_model_task(model):
|
||||
|
||||
# Guess from model filename
|
||||
if isinstance(model, (str, Path)):
|
||||
model = Path(model).stem
|
||||
if '-seg' in model:
|
||||
model = Path(model)
|
||||
if '-seg' in model.stem or 'segment' in model.parts:
|
||||
return 'segment'
|
||||
elif '-cls' in model:
|
||||
elif '-cls' in model.stem or 'classify' in model.parts:
|
||||
return 'classify'
|
||||
else:
|
||||
elif 'detect' in model.parts:
|
||||
return 'detect'
|
||||
|
||||
# Unable to determine task from model
|
||||
raise SyntaxError('YOLO is unable to automatically guess model task. Explicitly define task for your model, '
|
||||
"i.e. 'task=detect', 'task=segment' or 'task=classify'.")
|
||||
LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
|
||||
"Explicitly define task for your model, i.e. 'task=detect', 'task=segment' or 'task=classify'.")
|
||||
return 'detect' # assume detect
|
||||
|
Reference in New Issue
Block a user