ultralytics 8.0.35
TensorRT, ONNX and OpenVINO predict and val (#929)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Pedley <ericpedley@gmail.com>
This commit is contained in:
@ -13,7 +13,7 @@ import torch.nn as nn
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version, check_yaml
|
||||
from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
|
||||
from ultralytics.yolo.utils.ops import xywh2xyxy
|
||||
|
||||
@ -38,7 +38,7 @@ class AutoBackend(nn.Module):
|
||||
weights (str): The path to the weights file. Default: 'yolov8n.pt'
|
||||
device (torch.device): The device to run the model on.
|
||||
dnn (bool): Use OpenCV's DNN module for inference if True, defaults to False.
|
||||
data (dict): Additional data, optional
|
||||
data (str), (Path): Additional data.yaml file for class names, optional
|
||||
fp16 (bool): If True, use half precision. Default: False
|
||||
fuse (bool): Whether to fuse the model or not. Default: True
|
||||
|
||||
@ -237,7 +237,7 @@ class AutoBackend(nn.Module):
|
||||
|
||||
# class names
|
||||
if 'names' not in locals(): # names missing
|
||||
names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)} # assign default
|
||||
names = yaml_load(check_yaml(data))['names'] if data else {i: f'class{i}' for i in range(999)} # assign
|
||||
names = check_class_names(names)
|
||||
|
||||
self.__dict__.update(locals()) # assign all variables to self
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import contextlib
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import thop
|
||||
import torch
|
||||
@ -490,6 +491,14 @@ def guess_model_task(model):
|
||||
with contextlib.suppress(Exception):
|
||||
cfg = eval(x)
|
||||
break
|
||||
elif isinstance(model, (str, Path)):
|
||||
model = str(model)
|
||||
if '-seg' in model:
|
||||
return "segment"
|
||||
elif '-cls' in model:
|
||||
return "classify"
|
||||
else:
|
||||
return "detect"
|
||||
|
||||
# Guess from YAML dictionary
|
||||
if cfg:
|
||||
|
Reference in New Issue
Block a user