ultralytics 8.0.117 NAS export, classify and tasks banner URL fixes (#3145)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-06-12 20:14:46 +02:00
committed by GitHub
parent b59342b81c
commit c340f84ce9
8 changed files with 22 additions and 12 deletions

View File

@ -1,6 +1,12 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
# NAS model interface
YOLO-NAS model interface.
Usage - Predict:
from ultralytics import NAS
model = NAS('yolo_nas_s')
results = model.predict('ultralytics/assets/bus.jpg')
"""
from pathlib import Path
@ -33,11 +39,13 @@ class NAS:
self.model.args = DEFAULT_CFG_DICT # attach args to model
# Standardize model
self.model.fuse = lambda verbose: self.model
self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32])
self.model.names = dict(enumerate(self.model._class_names))
self.model.is_fused = lambda: False # for info()
self.model.yaml = {} # for info()
self.model.pt_path = model # for export()
self.model.task = 'detect' # for export()
self.info()
@smart_inference_mode()