diff --git a/docs/models/yolov8.md b/docs/models/yolov8.md index 240be83..882fc4c 100644 --- a/docs/models/yolov8.md +++ b/docs/models/yolov8.md @@ -21,12 +21,12 @@ YOLOv8 is the latest iteration in the YOLO series of real-time object detectors, ## Supported Tasks -| Model Type | Pre-trained Weights | Task | -|-------------|------------------------------------------------------------------------------------------------------------------|-----------------------| -| YOLOv8 | `yolov8n.pt`, `yolov8s.pt`, `yolov8m.pt`, `yolov8l.pt`, `yolov8x.pt` | Detection | -| YOLOv8-seg | `yolov8n-seg.pt`, `yolov8s-seg.pt`, `yolov8m-seg.pt`, `yolov8l-seg.pt`, `yolov8x-seg.pt` | Instance Segmentation | -| YOLOv8-pose | `yolov8n-pose.pt`, `yolov8s-pose.pt`, `yolov8m-pose.pt`, `yolov8l-pose.pt`, `yolov8x-pose.pt` ,`yolov8x-pose-p6` | Pose/Keypoints | -| YOLOv8-cls | `yolov8n-cls.pt`, `yolov8s-cls.pt`, `yolov8m-cls.pt`, `yolov8l-cls.pt`, `yolov8x-cls.pt` | Classification | +| Model Type | Pre-trained Weights | Task | +|-------------|---------------------------------------------------------------------------------------------------------------------|-----------------------| +| YOLOv8 | `yolov8n.pt`, `yolov8s.pt`, `yolov8m.pt`, `yolov8l.pt`, `yolov8x.pt` | Detection | +| YOLOv8-seg | `yolov8n-seg.pt`, `yolov8s-seg.pt`, `yolov8m-seg.pt`, `yolov8l-seg.pt`, `yolov8x-seg.pt` | Instance Segmentation | +| YOLOv8-pose | `yolov8n-pose.pt`, `yolov8s-pose.pt`, `yolov8m-pose.pt`, `yolov8l-pose.pt`, `yolov8x-pose.pt`, `yolov8x-pose-p6.pt` | Pose/Keypoints | +| YOLOv8-cls | `yolov8n-cls.pt`, `yolov8s-cls.pt`, `yolov8m-cls.pt`, `yolov8l-cls.pt`, `yolov8x-cls.pt` | Classification | ## Supported Modes diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index 9c97979..5b26643 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -177,7 +177,7 @@ class Exporter: im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device) file = Path( getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', '')) - if file.suffix == '.yaml': + if file.suffix in ('.yaml', '.yml'): file = Path(file.name) # Update model diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index 68891bf..8a8d597 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -88,7 +88,7 @@ class Model: suffix = Path(model).suffix if not suffix and Path(model).stem in GITHUB_ASSET_STEMS: model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt - if suffix == '.yaml': + if suffix in ('.yaml', '.yml'): self._new(model, task) else: self._load(model, task) diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 5c034cf..9532d30 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -119,7 +119,7 @@ class BaseTrainer: try: if self.args.task == 'classify': self.data = check_cls_dataset(self.args.data) - elif self.args.data.endswith('.yaml') or self.args.task in ('detect', 'segment'): + elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment'): self.data = check_det_dataset(self.args.data) if 'yaml_file' in self.data: self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py index e1382cd..2b08912 100644 --- a/ultralytics/engine/validator.py +++ b/ultralytics/engine/validator.py @@ -126,7 +126,7 @@ class BaseValidator: self.args.batch = 1 # export.py models default to batch-size 1 LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models') - if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'): + if isinstance(self.args.data, str) and self.args.data.split('.')[-1] in ('yaml', 'yml'): self.data = check_det_dataset(self.args.data) elif self.args.task == 'classify': self.data = check_cls_dataset(self.args.data, split=self.args.split) diff --git a/ultralytics/models/fastsam/model.py b/ultralytics/models/fastsam/model.py index 6cfedc4..a4bfada 100644 --- a/ultralytics/models/fastsam/model.py +++ b/ultralytics/models/fastsam/model.py @@ -23,7 +23,7 @@ class FastSAM(Model): """Call the __init__ method of the parent class (YOLO) with the updated default model""" if model == 'FastSAM.pt': model = 'FastSAM-x.pt' - assert Path(model).suffix != '.yaml', 'FastSAM models only support pre-trained models.' + assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.' super().__init__(model=model, task='segment') @property diff --git a/ultralytics/models/nas/model.py b/ultralytics/models/nas/model.py index 518a7c8..1f7cd35 100644 --- a/ultralytics/models/nas/model.py +++ b/ultralytics/models/nas/model.py @@ -23,7 +23,7 @@ from .val import NASValidator class NAS(Model): def __init__(self, model='yolo_nas_s.pt') -> None: - assert Path(model).suffix != '.yaml', 'YOLO-NAS models only support pre-trained models.' + assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.' super().__init__(model, task='detect') @smart_inference_mode() diff --git a/ultralytics/models/rtdetr/model.py b/ultralytics/models/rtdetr/model.py index 5612a04..aa99f9d 100644 --- a/ultralytics/models/rtdetr/model.py +++ b/ultralytics/models/rtdetr/model.py @@ -16,8 +16,8 @@ class RTDETR(Model): """ def __init__(self, model='rtdetr-l.pt') -> None: - if model and not model.endswith('.pt') and not model.endswith('.yaml'): - raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.') + if model and not model.split('.')[-1] in ('pt', 'yaml', 'yml'): + raise NotImplementedError('RT-DETR only supports creating from *.pt file or *.yaml file.') super().__init__(model=model, task='detect') @property diff --git a/ultralytics/models/yolo/classify/train.py b/ultralytics/models/yolo/classify/train.py index 0494708..eda92ad 100644 --- a/ultralytics/models/yolo/classify/train.py +++ b/ultralytics/models/yolo/classify/train.py @@ -57,7 +57,7 @@ class ClassificationTrainer(BaseTrainer): self.model, _ = attempt_load_one_weight(model, device='cpu') for p in self.model.parameters(): p.requires_grad = True # for training - elif model.endswith('.yaml'): + elif model.split('.')[-1] in ('yaml', 'yml'): self.model = self.get_model(cfg=model) elif model in torchvision.models.__dict__: self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None) diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py index 7a67e4f..7dfadce 100644 --- a/ultralytics/utils/benchmarks.py +++ b/ultralytics/utils/benchmarks.py @@ -192,7 +192,7 @@ class ProfileModels: output = [] for file in files: engine_file = file.with_suffix('.engine') - if file.suffix in ('.pt', '.yaml'): + if file.suffix in ('.pt', '.yaml', '.yml'): model = YOLO(str(file)) model.fuse() # to report correct params and GFLOPs in model.info() model_info = model.info() @@ -229,7 +229,7 @@ class ProfileModels: if path.is_dir(): extensions = ['*.pt', '*.onnx', '*.yaml'] files.extend([file for ext in extensions for file in glob.glob(str(path / ext))]) - elif path.suffix in {'.pt', '.yaml'}: # add non-existing + elif path.suffix in ('.pt', '.yaml', '.yml'): # add non-existing files.append(str(path)) else: files.extend(glob.glob(str(path)))