ultralytics 8.0.45
segment CUDA and DDP callback fixes (#1137)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
@ -292,7 +292,10 @@ class Exporter:
|
||||
@try_export
|
||||
def _export_onnx(self, prefix=colorstr('ONNX:')):
|
||||
# YOLOv8 ONNX export
|
||||
check_requirements('onnx>=1.12.0')
|
||||
requirements = ['onnx>=1.12.0']
|
||||
if self.args.simplify:
|
||||
requirements += ['onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
|
||||
check_requirements(requirements)
|
||||
import onnx # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
|
||||
@ -326,7 +329,6 @@ class Exporter:
|
||||
# Simplify
|
||||
if self.args.simplify:
|
||||
try:
|
||||
check_requirements(('onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'))
|
||||
import onnxsim
|
||||
|
||||
LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
|
||||
@ -508,9 +510,8 @@ class Exporter:
|
||||
try:
|
||||
import tensorflow as tf # noqa
|
||||
except ImportError:
|
||||
check_requirements(
|
||||
f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if torch.cuda.is_available() else '-cpu'}"
|
||||
)
|
||||
cuda = torch.cuda.is_available()
|
||||
check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
|
||||
import tensorflow as tf # noqa
|
||||
check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support',
|
||||
'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),
|
||||
|
@ -64,7 +64,7 @@ class YOLO:
|
||||
Performs prediction using the YOLO model.
|
||||
|
||||
Returns:
|
||||
list[ultralytics.yolo.engine.results.Results]: The prediction results.
|
||||
list(ultralytics.yolo.engine.results.Results): The prediction results.
|
||||
"""
|
||||
|
||||
def __init__(self, model='yolov8n.pt') -> None:
|
||||
|
@ -111,14 +111,14 @@ class Results:
|
||||
|
||||
Args:
|
||||
show_conf (bool): Whether to show the detection confidence score.
|
||||
line_width (float, optional): The line width of the bounding boxes. If None, it is automatically scaled to the image size.
|
||||
font_size (float, optional): The font size of the text. If None, it is automatically scaled to the image size.
|
||||
line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
|
||||
font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
|
||||
font (str): The font to use for the text.
|
||||
pil (bool): Whether to return the image as a PIL Image.
|
||||
example (str): An example string to display in the plot. Useful for indicating the expected format of the output.
|
||||
example (str): An example string to display. Useful for indicating the expected format of the output.
|
||||
|
||||
Returns:
|
||||
None or PIL Image: If `pil` is True, the image will be returned as a PIL Image. Otherwise, nothing is returned.
|
||||
(None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned.
|
||||
"""
|
||||
img = deepcopy(self.orig_img)
|
||||
annotator = Annotator(img, line_width, font_size, font, pil, example)
|
||||
@ -284,7 +284,7 @@ class Masks:
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Properties:
|
||||
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection masks.
|
||||
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection.
|
||||
|
||||
Methods:
|
||||
cpu(): Returns a copy of the masks tensor on CPU memory.
|
||||
|
@ -181,7 +181,7 @@ class BaseTrainer:
|
||||
LOGGER.info(f'Running DDP command {cmd}')
|
||||
subprocess.run(cmd, check=True)
|
||||
except Exception as e:
|
||||
LOGGER.warning(e)
|
||||
raise e
|
||||
finally:
|
||||
ddp_cleanup(self, str(file))
|
||||
else:
|
||||
|
@ -63,7 +63,6 @@ class BaseValidator:
|
||||
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
|
||||
save_dir (Path): Directory to save results.
|
||||
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
||||
logger (logging.Logger): Logger to log messages.
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
"""
|
||||
self.dataloader = dataloader
|
||||
|
Reference in New Issue
Block a user