`ultralytics 8.0.45` segment CUDA and DDP callback fixes (#1137)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent bfc078b32f
commit 3765f4f6d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = '8.0.44'
__version__ = '8.0.45'
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils.checks import check_yolo as checks

@ -48,19 +48,21 @@ CLI_HELP_MSG = \
"""
# Define keys for arg type checks
CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'fl_gamma'}
CFG_FRACTION_KEYS = {
'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'label_smoothing',
'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic', 'mixup', 'copy_paste',
'conf', 'iou'} # fractional floats limited to 0.0 - 1.0
CFG_INT_KEYS = {
'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
'line_thickness', 'workspace', 'nbs', 'save_period'}
CFG_BOOL_KEYS = {
'save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr',
'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf',
'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader'}
CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'fl_gamma'
CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou') # fractional floats limited to 0.0 - 1.0
CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
'line_thickness', 'workspace', 'nbs', 'save_period')
CFG_BOOL_KEYS = ('save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect',
'cos_lr', 'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show',
'save_txt', 'save_conf', 'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment',
'agnostic_nms', 'retina_masks', 'boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms',
'v5loader')
# Define valid tasks and modes
TASKS = 'detect', 'segment', 'classify'
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
def cfg2dict(cfg):
@ -196,9 +198,6 @@ def entrypoint(debug=''):
LOGGER.info(CLI_HELP_MSG)
return
# Define tasks and modes
tasks = 'detect', 'segment', 'classify'
modes = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
special = {
'help': lambda: LOGGER.info(CLI_HELP_MSG),
'checks': checks.check_yolo,
@ -206,7 +205,7 @@ def entrypoint(debug=''):
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
'copy-cfg': copy_default_cfg}
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in tasks}, **{k: None for k in modes}, **special}
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
# Define common mis-uses of special commands, i.e. -h, -help, --help
special.update({k[0]: v for k, v in special.items()}) # singular
@ -240,9 +239,9 @@ def entrypoint(debug=''):
except (NameError, SyntaxError, ValueError, AssertionError) as e:
check_cfg_mismatch(full_args_dict, {a: ''}, e)
elif a in tasks:
elif a in TASKS:
overrides['task'] = a
elif a in modes:
elif a in MODES:
overrides['mode'] = a
elif a in special:
special[a]()
@ -262,10 +261,10 @@ def entrypoint(debug=''):
mode = overrides.get('mode', None)
if mode is None:
mode = DEFAULT_CFG.mode or 'predict'
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
elif mode not in modes:
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
elif mode not in MODES:
if mode not in ('checks', checks):
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {modes}.\n{CLI_HELP_MSG}")
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
checks.check_yolo()
return
@ -280,11 +279,11 @@ def entrypoint(debug=''):
model = YOLO(model)
# Task
# if task and task != model.task:
# LOGGER.warning(f"WARNING ⚠️ 'task={task}' conflicts with {model.task} model {overrides['model']}. "
# f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.")
overrides['task'] = overrides.get('task', model.task)
model.task = overrides['task']
task = overrides.get('task', None)
if task is not None and task not in TASKS:
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
else:
model.task = task
# Mode
if mode in {'predict', 'track'} and 'source' not in overrides:

@ -292,7 +292,10 @@ class Exporter:
@try_export
def _export_onnx(self, prefix=colorstr('ONNX:')):
# YOLOv8 ONNX export
check_requirements('onnx>=1.12.0')
requirements = ['onnx>=1.12.0']
if self.args.simplify:
requirements += ['onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
check_requirements(requirements)
import onnx # noqa
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
@ -326,7 +329,6 @@ class Exporter:
# Simplify
if self.args.simplify:
try:
check_requirements(('onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'))
import onnxsim
LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...')
@ -508,9 +510,8 @@ class Exporter:
try:
import tensorflow as tf # noqa
except ImportError:
check_requirements(
f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if torch.cuda.is_available() else '-cpu'}"
)
cuda = torch.cuda.is_available()
check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
import tensorflow as tf # noqa
check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support',
'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),

@ -64,7 +64,7 @@ class YOLO:
Performs prediction using the YOLO model.
Returns:
list[ultralytics.yolo.engine.results.Results]: The prediction results.
list(ultralytics.yolo.engine.results.Results): The prediction results.
"""
def __init__(self, model='yolov8n.pt') -> None:

@ -111,14 +111,14 @@ class Results:
Args:
show_conf (bool): Whether to show the detection confidence score.
line_width (float, optional): The line width of the bounding boxes. If None, it is automatically scaled to the image size.
font_size (float, optional): The font size of the text. If None, it is automatically scaled to the image size.
line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
font (str): The font to use for the text.
pil (bool): Whether to return the image as a PIL Image.
example (str): An example string to display in the plot. Useful for indicating the expected format of the output.
example (str): An example string to display. Useful for indicating the expected format of the output.
Returns:
None or PIL Image: If `pil` is True, the image will be returned as a PIL Image. Otherwise, nothing is returned.
(None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned.
"""
img = deepcopy(self.orig_img)
annotator = Annotator(img, line_width, font_size, font, pil, example)
@ -284,7 +284,7 @@ class Masks:
orig_shape (tuple): Original image size, in the format (height, width).
Properties:
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection masks.
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection.
Methods:
cpu(): Returns a copy of the masks tensor on CPU memory.

@ -181,7 +181,7 @@ class BaseTrainer:
LOGGER.info(f'Running DDP command {cmd}')
subprocess.run(cmd, check=True)
except Exception as e:
LOGGER.warning(e)
raise e
finally:
ddp_cleanup(self, str(file))
else:

@ -63,7 +63,6 @@ class BaseValidator:
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
save_dir (Path): Directory to save results.
pbar (tqdm.tqdm): Progress bar for displaying progress.
logger (logging.Logger): Logger to log messages.
args (SimpleNamespace): Configuration for the validator.
"""
self.dataloader = dataloader

@ -24,8 +24,6 @@ def find_free_network_port() -> int:
def generate_ddp_file(trainer):
import_path = '.'.join(str(trainer.__class__).split('.')[1:-1])
if not trainer.resume:
shutil.rmtree(trainer.save_dir) # remove the save_dir
content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__":
from ultralytics.{import_path} import {trainer.__class__.__name__}
@ -43,16 +41,17 @@ def generate_ddp_file(trainer):
def generate_ddp_command(world_size, trainer):
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
# Get file and args (do not use sys.argv due to security vulnerability)
exclude_args = ['save_dir']
args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
file = generate_ddp_file(trainer) # if argv[0].endswith('yolo') else os.path.abspath(argv[0])
# Build command
import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
file = os.path.abspath(sys.argv[0])
using_cli = not file.endswith('.py')
if not trainer.resume:
shutil.rmtree(trainer.save_dir) # remove the save_dir
if using_cli:
file = generate_ddp_file(trainer)
dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
port = find_free_network_port()
exclude_args = ['save_dir']
args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
return cmd, file

Loading…
Cancel
Save