CoreML NMS and half fixes (#143)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -201,7 +201,7 @@ def check_dataset_yaml(data, autodownload=True):
|
||||
extract_dir, autodownload = data.parent, False
|
||||
# Read yaml (optional)
|
||||
if isinstance(data, (str, Path)):
|
||||
data = yaml_load(data) # dictionary
|
||||
data = yaml_load(data, append_filename=True) # dictionary
|
||||
|
||||
# Checks
|
||||
for k in 'train', 'val', 'names':
|
||||
|
@ -67,7 +67,7 @@ import torch
|
||||
|
||||
import ultralytics
|
||||
from ultralytics.nn.modules import Detect, Segment
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
||||
from ultralytics.yolo.data.utils import check_dataset
|
||||
@ -154,7 +154,7 @@ class Exporter:
|
||||
# Load PyTorch model
|
||||
self.device = select_device(self.args.device or 'cpu')
|
||||
if self.args.half:
|
||||
if self.device.type == 'cpu' or not coreml:
|
||||
if self.device.type == 'cpu' and not coreml:
|
||||
LOGGER.info('half=True only compatible with GPU or CoreML export, i.e. use device=0 or format=coreml')
|
||||
self.args.half = False
|
||||
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
|
||||
@ -769,17 +769,22 @@ class Exporter:
|
||||
def export(cfg):
|
||||
cfg.model = cfg.model or "yolov8n.yaml"
|
||||
cfg.format = cfg.format or "torchscript"
|
||||
exporter = Exporter(cfg)
|
||||
|
||||
model = None
|
||||
if isinstance(cfg.model, (str, Path)):
|
||||
if Path(cfg.model).suffix == '.yaml':
|
||||
model = DetectionModel(cfg.model)
|
||||
elif Path(cfg.model).suffix == '.pt':
|
||||
model = attempt_load_weights(cfg.model, fuse=True)
|
||||
else:
|
||||
TypeError(f'Unsupported model type {cfg.model}')
|
||||
exporter(model=model)
|
||||
# exporter = Exporter(cfg)
|
||||
#
|
||||
# model = None
|
||||
# if isinstance(cfg.model, (str, Path)):
|
||||
# if Path(cfg.model).suffix == '.yaml':
|
||||
# model = DetectionModel(cfg.model)
|
||||
# elif Path(cfg.model).suffix == '.pt':
|
||||
# model = attempt_load_weights(cfg.model, fuse=True)
|
||||
# else:
|
||||
# TypeError(f'Unsupported model type {cfg.model}')
|
||||
# exporter(model=model)
|
||||
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(cfg.model)
|
||||
model.export(**cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -64,7 +64,7 @@ class YOLO:
|
||||
verbose (bool): display model info on load
|
||||
"""
|
||||
cfg = check_yaml(cfg) # check YAML
|
||||
cfg_dict = yaml_load(cfg) # model dict
|
||||
cfg_dict = yaml_load(cfg, append_filename=True) # model dict
|
||||
self.task = guess_task_from_head(cfg_dict["head"][-1][-2])
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
||||
self._guess_ops_from_task(self.task)
|
||||
@ -183,7 +183,7 @@ class YOLO:
|
||||
overrides.update(kwargs)
|
||||
if kwargs.get("cfg"):
|
||||
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
|
||||
overrides = yaml_load(check_yaml(kwargs["cfg"]))
|
||||
overrides = yaml_load(check_yaml(kwargs["cfg"]), append_filename=True)
|
||||
overrides["task"] = self.task
|
||||
overrides["mode"] = "train"
|
||||
if not overrides.get("data"):
|
||||
|
@ -157,7 +157,8 @@ class BaseValidator:
|
||||
self.run_callbacks('on_val_end')
|
||||
if self.training:
|
||||
model.float()
|
||||
return {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
|
||||
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
|
||||
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
|
||||
else:
|
||||
self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
|
||||
self.speed)
|
||||
|
@ -3,6 +3,7 @@ import inspect
|
||||
import logging.config
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
@ -171,6 +172,18 @@ def is_dir_writeable(dir_path: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_git_root_dir():
|
||||
"""
|
||||
Determines whether the current file is part of a git repository and if so, returns the repository root directory.
|
||||
If the current file is not part of a git repository, returns None.
|
||||
"""
|
||||
try:
|
||||
output = subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True)
|
||||
return Path(output.stdout.strip().decode('utf-8')).parent # parent/.git
|
||||
except subprocess.CalledProcessError:
|
||||
return None
|
||||
|
||||
|
||||
def get_default_args(func):
|
||||
# Get func() default arguments
|
||||
signature = inspect.signature(func)
|
||||
@ -311,13 +324,13 @@ def yaml_save(file='data.yaml', data=None):
|
||||
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
||||
|
||||
|
||||
def yaml_load(file='data.yaml', append_filename=True):
|
||||
def yaml_load(file='data.yaml', append_filename=False):
|
||||
"""
|
||||
Load YAML data from a file.
|
||||
|
||||
Args:
|
||||
file (str, optional): File name. Default is 'data.yaml'.
|
||||
append_filename (bool): Add the YAML filename to the YAML dictionary. Default is True.
|
||||
append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.
|
||||
|
||||
Returns:
|
||||
dict: YAML data and file name.
|
||||
@ -339,14 +352,13 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'):
|
||||
"""
|
||||
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
|
||||
|
||||
git_install = not is_pip_package()
|
||||
root = get_git_root_dir() or Path('') # not is_pip_package()
|
||||
defaults = {
|
||||
'datasets_dir': str(ROOT / 'datasets') if git_install else 'datasets', # default datasets directory.
|
||||
'weights_dir': str(ROOT / 'weights') if git_install else 'weights', # default weights directory.
|
||||
'runs_dir': str(ROOT / 'runs') if git_install else 'runs', # default runs directory.
|
||||
'datasets_dir': str(root / 'datasets'), # default datasets directory.
|
||||
'weights_dir': str(root / 'weights'), # default weights directory.
|
||||
'runs_dir': str(root / 'runs'), # default runs directory.
|
||||
'sync': True, # sync analytics to help with YOLO development
|
||||
'uuid': uuid.getnode(), # device UUID to align analytics
|
||||
'yaml_file': str(file)} # setting YAML file path
|
||||
'uuid': uuid.getnode()} # device UUID to align analytics
|
||||
|
||||
with torch_distributed_zero_first(RANK):
|
||||
if not file.exists():
|
||||
|
@ -18,13 +18,7 @@ def on_pretrain_routine_end(trainer):
|
||||
def on_fit_epoch_end(trainer):
|
||||
session = getattr(trainer, 'hub_session', None)
|
||||
if session:
|
||||
# Upload metrics after val end
|
||||
metrics = trainer.metrics
|
||||
for k, v in metrics.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
metrics[k] = v.item()
|
||||
|
||||
session.metrics_queue[trainer.epoch] = json.dumps(metrics) # json string
|
||||
session.metrics_queue[trainer.epoch] = json.dumps(trainer.metrics) # json string
|
||||
if time() - session.t['metrics'] > session.rate_limits['metrics']:
|
||||
session.upload_metrics()
|
||||
session.t['metrics'] = time() # reset timer
|
||||
|
@ -153,7 +153,7 @@ def check_python(minimum: str = '3.7.0') -> bool:
|
||||
|
||||
|
||||
@TryExcept()
|
||||
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=''):
|
||||
def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=(), install=True, cmds=''):
|
||||
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages or single package str)
|
||||
prefix = colorstr('red', 'bold', 'requirements:')
|
||||
check_python() # check python version
|
||||
|
@ -19,7 +19,7 @@ class DetectionValidator(BaseValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
||||
super().__init__(dataloader, save_dir, pbar, logger, args)
|
||||
self.data_dict = yaml_load(check_file(self.args.data)) if self.args.data else None
|
||||
self.data_dict = yaml_load(check_file(self.args.data), append_filename=True) if self.args.data else None
|
||||
self.is_coco = False
|
||||
self.class_map = None
|
||||
self.targets = None
|
||||
|
Reference in New Issue
Block a user