Support both *.yml and *.yaml files (#4086)

Co-authored-by: ChristopherRogers1991 <ChristopherRogers1991@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sungjoo(Dennis) Hwang <48212469+Denny-Hwang@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-08-01 15:41:09 +02:00
committed by GitHub
parent f2ed85790f
commit b507e3a032
10 changed files with 17 additions and 17 deletions

View File

@ -177,7 +177,7 @@ class Exporter:
im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
file = Path(
getattr(model, 'pt_path', None) or getattr(model, 'yaml_file', None) or model.yaml.get('yaml_file', ''))
if file.suffix == '.yaml':
if file.suffix in ('.yaml', '.yml'):
file = Path(file.name)
# Update model

View File

@ -88,7 +88,7 @@ class Model:
suffix = Path(model).suffix
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
if suffix == '.yaml':
if suffix in ('.yaml', '.yml'):
self._new(model, task)
else:
self._load(model, task)

View File

@ -119,7 +119,7 @@ class BaseTrainer:
try:
if self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data)
elif self.args.data.endswith('.yaml') or self.args.task in ('detect', 'segment'):
elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment'):
self.data = check_det_dataset(self.args.data)
if 'yaml_file' in self.data:
self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage

View File

@ -126,7 +126,7 @@ class BaseValidator:
self.args.batch = 1 # export.py models default to batch-size 1
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
if isinstance(self.args.data, str) and self.args.data.split('.')[-1] in ('yaml', 'yml'):
self.data = check_det_dataset(self.args.data)
elif self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data, split=self.args.split)