ultralytics 8.0.35 TensorRT, ONNX and OpenVINO predict and val (#929)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Pedley <ericpedley@gmail.com>
This commit is contained in:
Glenn Jocher
2023-02-11 21:31:49 +04:00
committed by GitHub
parent d32b339373
commit 977fd8f0b8
15 changed files with 88 additions and 69 deletions

View File

@ -13,7 +13,7 @@ import torch.nn as nn
from PIL import Image
from ultralytics.yolo.utils import LOGGER, ROOT, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version, check_yaml
from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
from ultralytics.yolo.utils.ops import xywh2xyxy
@ -38,7 +38,7 @@ class AutoBackend(nn.Module):
weights (str): The path to the weights file. Default: 'yolov8n.pt'
device (torch.device): The device to run the model on.
dnn (bool): Use OpenCV's DNN module for inference if True, defaults to False.
data (dict): Additional data, optional
data (str), (Path): Additional data.yaml file for class names, optional
fp16 (bool): If True, use half precision. Default: False
fuse (bool): Whether to fuse the model or not. Default: True
@ -237,7 +237,7 @@ class AutoBackend(nn.Module):
# class names
if 'names' not in locals(): # names missing
names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)} # assign default
names = yaml_load(check_yaml(data))['names'] if data else {i: f'class{i}' for i in range(999)} # assign
names = check_class_names(names)
self.__dict__.update(locals()) # assign all variables to self

View File

@ -2,6 +2,7 @@
import contextlib
from copy import deepcopy
from pathlib import Path
import thop
import torch
@ -490,6 +491,14 @@ def guess_model_task(model):
with contextlib.suppress(Exception):
cfg = eval(x)
break
elif isinstance(model, (str, Path)):
model = str(model)
if '-seg' in model:
return "segment"
elif '-cls' in model:
return "classify"
else:
return "detect"
# Guess from YAML dictionary
if cfg: