ultralytics 8.0.117 NAS export, classify and tasks banner URL fixes (#3145)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-06-12 20:14:46 +02:00
committed by GitHub
parent b59342b81c
commit c340f84ce9
8 changed files with 22 additions and 12 deletions

View File

@ -172,7 +172,8 @@ class Exporter:
# Input
im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
file = Path(getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml['yaml_file'])
file = Path(
getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', ''))
if file.suffix == '.yaml':
file = Path(file.name)
@ -207,7 +208,8 @@ class Exporter:
self.im = im
self.model = model
self.file = file
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else \
tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
trained_on = f'trained on {Path(self.args.data).name}' if self.args.data else '(untrained)'
description = f'Ultralytics {self.pretty_name} model {trained_on}'