# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
YOLO-NAS model interface.

Usage - Predict:
    from ultralytics import NAS

    model = NAS('yolo_nas_s')
    results = model.predict('ultralytics/assets/bus.jpg')
"""

from pathlib import Path

import torch

from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info, smart_inference_mode

from .predict import NASPredictor
from .val import NASValidator


class NAS(Model):
    """
    YOLO-NAS detection model wrapper.

    Only pre-trained models are accepted: a local ``.pt`` file, or a bare model name (no suffix) that is
    fetched through ``super_gradients`` with COCO weights. YAML model configs are rejected.
    """

    def __init__(self, model: str = 'yolo_nas_s.pt') -> None:
        """
        Initialize a YOLO-NAS model for the 'detect' task.

        Args:
            model (str): Path to a ``.pt`` file, or a super-gradients model name such as 'yolo_nas_s'.

        Raises:
            AssertionError: If ``model`` points to a YAML config (``.yaml`` or ``.yml``).
        """
        # Both YAML spellings denote a model config; only pre-trained models are supported.
        assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
        super().__init__(model, task='detect')

    @smart_inference_mode()
    def _load(self, weights: str, task: str):
        """
        Load pre-trained NAS weights and patch the model with the attributes the Ultralytics engine reads.

        Args:
            weights (str): Path to a ``.pt`` file, or a super-gradients model name with no suffix.
            task (str): Task name; not used here, the model task is always set to 'detect'.

        Raises:
            ValueError: If ``weights`` has a suffix other than ``.pt`` or no suffix at all.
        """
        # Deferred import: super_gradients is only needed when a NAS model is actually loaded.
        import super_gradients

        suffix = Path(weights).suffix
        if suffix == '.pt':
            # NOTE: torch.load unpickles arbitrary objects; only load .pt files from trusted sources.
            self.model = torch.load(weights)
        elif suffix == '':
            self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
        else:
            # Fail fast: otherwise the attribute patches below would run against an unset model.
            raise ValueError(
                f"Unsupported YOLO-NAS weights '{weights}': expected a '.pt' file or a model name without suffix.")

        # Standardize model so the generic Ultralytics engine can drive it
        self.model.fuse = lambda verbose=True: self.model  # fusing is a no-op; returns the model unchanged
        self.model.stride = torch.tensor([32])
        self.model.names = dict(enumerate(self.model._class_names))  # index -> entry of _class_names
        self.model.is_fused = lambda: False  # for info()
        self.model.yaml = {}  # for info()
        self.model.pt_path = weights  # for export()
        self.model.task = 'detect'  # for export()

    def info(self, detailed=False, verbose=True):
        """
        Logs model info.

        Args:
            detailed (bool): Show detailed information about model.
            verbose (bool): Controls verbosity.

        Returns:
            Whatever ``model_info`` returns for this model at a fixed image size of 640.
        """
        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

    @property
    def task_map(self):
        """Map the 'detect' task to the NAS-specific predictor and validator classes."""
        return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}