ultralytics 8.0.45 segment CUDA and DDP callback fixes (#1137)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-02-24 12:43:53 +01:00
committed by GitHub
parent bfc078b32f
commit 3765f4f6d9
8 changed files with 49 additions and 51 deletions

View File

@ -292,7 +292,10 @@ class Exporter:
@try_export
def _export_onnx(self, prefix=colorstr('ONNX:')):
# YOLOv8 ONNX export
check_requirements('onnx>=1.12.0')
requirements = ['onnx>=1.12.0']
if self.args.simplify:
requirements += ['onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
check_requirements(requirements)
import onnx # noqa
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
@ -326,7 +329,6 @@ class Exporter:
# Simplify
if self.args.simplify:
try:
check_requirements(('onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'))
import onnxsim
LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
@ -508,9 +510,8 @@ class Exporter:
try:
import tensorflow as tf # noqa
except ImportError:
check_requirements(
f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if torch.cuda.is_available() else '-cpu'}"
)
cuda = torch.cuda.is_available()
check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
import tensorflow as tf # noqa
check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support',
'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),

View File

@ -64,7 +64,7 @@ class YOLO:
Performs prediction using the YOLO model.
Returns:
list[ultralytics.yolo.engine.results.Results]: The prediction results.
list(ultralytics.yolo.engine.results.Results): The prediction results.
"""
def __init__(self, model='yolov8n.pt') -> None:

View File

@ -111,14 +111,14 @@ class Results:
Args:
show_conf (bool): Whether to show the detection confidence score.
line_width (float, optional): The line width of the bounding boxes. If None, it is automatically scaled to the image size.
font_size (float, optional): The font size of the text. If None, it is automatically scaled to the image size.
line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
font (str): The font to use for the text.
pil (bool): Whether to return the image as a PIL Image.
example (str): An example string to display in the plot. Useful for indicating the expected format of the output.
example (str): An example string to display. Useful for indicating the expected format of the output.
Returns:
None or PIL Image: If `pil` is True, the image will be returned as a PIL Image. Otherwise, nothing is returned.
(None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned.
"""
img = deepcopy(self.orig_img)
annotator = Annotator(img, line_width, font_size, font, pil, example)
@ -284,7 +284,7 @@ class Masks:
orig_shape (tuple): Original image size, in the format (height, width).
Properties:
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection masks.
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection.
Methods:
cpu(): Returns a copy of the masks tensor on CPU memory.

View File

@ -181,7 +181,7 @@ class BaseTrainer:
LOGGER.info(f'Running DDP command {cmd}')
subprocess.run(cmd, check=True)
except Exception as e:
LOGGER.warning(e)
raise e
finally:
ddp_cleanup(self, str(file))
else:

View File

@ -63,7 +63,6 @@ class BaseValidator:
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
save_dir (Path): Directory to save results.
pbar (tqdm.tqdm): Progress bar for displaying progress.
logger (logging.Logger): Logger to log messages.
args (SimpleNamespace): Configuration for the validator.
"""
self.dataloader = dataloader