ultralytics 8.0.117
NAS export, classify and tasks banner URL fixes (#3145)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -172,7 +172,8 @@ class Exporter:
|
||||
|
||||
# Input
|
||||
im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
|
||||
file = Path(getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml['yaml_file'])
|
||||
file = Path(
|
||||
getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', ''))
|
||||
if file.suffix == '.yaml':
|
||||
file = Path(file.name)
|
||||
|
||||
@ -207,7 +208,8 @@ class Exporter:
|
||||
self.im = im
|
||||
self.model = model
|
||||
self.file = file
|
||||
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
|
||||
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else \
|
||||
tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y)
|
||||
self.pretty_name = Path(self.model.yaml.get('yaml_file', self.file)).stem.replace('yolo', 'YOLO')
|
||||
trained_on = f'trained on {Path(self.args.data).name}' if self.args.data else '(untrained)'
|
||||
description = f'Ultralytics {self.pretty_name} model {trained_on}'
|
||||
|
@ -1,6 +1,12 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
# NAS model interface
|
||||
YOLO-NAS model interface.
|
||||
|
||||
Usage - Predict:
|
||||
from ultralytics import NAS
|
||||
|
||||
model = NAS('yolo_nas_s')
|
||||
results = model.predict('ultralytics/assets/bus.jpg')
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@ -33,11 +39,13 @@ class NAS:
|
||||
self.model.args = DEFAULT_CFG_DICT # attach args to model
|
||||
|
||||
# Standardize model
|
||||
self.model.fuse = lambda verbose: self.model
|
||||
self.model.fuse = lambda verbose=True: self.model
|
||||
self.model.stride = torch.tensor([32])
|
||||
self.model.names = dict(enumerate(self.model._class_names))
|
||||
self.model.is_fused = lambda: False # for info()
|
||||
self.model.yaml = {} # for info()
|
||||
self.model.pt_path = model # for export()
|
||||
self.model.task = 'detect' # for export()
|
||||
self.info()
|
||||
|
||||
@smart_inference_mode()
|
||||
|
Reference in New Issue
Block a user