ultralytics 8.0.117
NAS export, classify and tasks banner URL fixes (#3145)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,6 +1,12 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
# NAS model interface
|
||||
YOLO-NAS model interface.
|
||||
|
||||
Usage - Predict:
|
||||
from ultralytics import NAS
|
||||
|
||||
model = NAS('yolo_nas_s')
|
||||
results = model.predict('ultralytics/assets/bus.jpg')
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@ -33,11 +39,13 @@ class NAS:
|
||||
self.model.args = DEFAULT_CFG_DICT # attach args to model
|
||||
|
||||
# Standardize model
|
||||
self.model.fuse = lambda verbose: self.model
|
||||
self.model.fuse = lambda verbose=True: self.model
|
||||
self.model.stride = torch.tensor([32])
|
||||
self.model.names = dict(enumerate(self.model._class_names))
|
||||
self.model.is_fused = lambda: False # for info()
|
||||
self.model.yaml = {} # for info()
|
||||
self.model.pt_path = model # for export()
|
||||
self.model.task = 'detect' # for export()
|
||||
self.info()
|
||||
|
||||
@smart_inference_mode()
|
||||
|
Reference in New Issue
Block a user