ultralytics 8.0.31
updates and fixes (#857)
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kalen Michael <kalenmike@gmail.com>
This commit is contained in:
@ -49,6 +49,20 @@ CLI_HELP_MSG = \
|
||||
GitHub: https://github.com/ultralytics/ultralytics
|
||||
"""
|
||||
|
||||
CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl'}
|
||||
CFG_FRACTION_KEYS = {
|
||||
'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'fl_gamma',
|
||||
'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'degrees', 'translate', 'scale', 'shear', 'perspective', 'flipud',
|
||||
'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou'}
|
||||
CFG_INT_KEYS = {
|
||||
'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
|
||||
'line_thickness', 'workspace', 'nbs'}
|
||||
CFG_BOOL_KEYS = {
|
||||
'save', 'cache', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect',
|
||||
'cos_lr', 'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt',
|
||||
'save_conf', 'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks',
|
||||
'boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader'}
|
||||
|
||||
|
||||
def cfg2dict(cfg):
|
||||
"""
|
||||
@ -88,11 +102,31 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG, override
|
||||
check_cfg_mismatch(cfg, overrides)
|
||||
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
|
||||
|
||||
# Type checks
|
||||
# Special handling for numeric project/names
|
||||
for k in 'project', 'name':
|
||||
if k in cfg and isinstance(cfg[k], (int, float)):
|
||||
cfg[k] = str(cfg[k])
|
||||
|
||||
# Type and Value checks
|
||||
for k, v in cfg.items():
|
||||
if v is not None: # None values may be from optional args
|
||||
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
|
||||
elif k in CFG_FRACTION_KEYS:
|
||||
if not isinstance(v, (int, float)):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')")
|
||||
if not (0.0 <= v <= 1.0):
|
||||
raise ValueError(f"'{k}={v}' is an invalid value. "
|
||||
f"Valid '{k}' values are between 0.0 and 1.0.")
|
||||
elif k in CFG_INT_KEYS and not isinstance(v, int):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"'{k}' must be an int (i.e. '{k}=0')")
|
||||
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
|
||||
raise TypeError(f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')")
|
||||
|
||||
# Return instance
|
||||
return IterableSimpleNamespace(**cfg)
|
||||
|
||||
@ -184,6 +218,7 @@ def entrypoint(debug=''):
|
||||
try:
|
||||
re.sub(r' *= *', '=', a) # remove spaces around equals sign
|
||||
k, v = a.split('=', 1) # split on first '=' sign
|
||||
assert v, f"missing '{k}' value"
|
||||
if k == 'cfg': # custom.yaml passed
|
||||
LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {v}")
|
||||
overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
|
||||
@ -198,7 +233,7 @@ def entrypoint(debug=''):
|
||||
with contextlib.suppress(Exception):
|
||||
v = eval(v)
|
||||
overrides[k] = v
|
||||
except (NameError, SyntaxError, ValueError) as e:
|
||||
except (NameError, SyntaxError, ValueError, AssertionError) as e:
|
||||
raise argument_error(a) from e
|
||||
|
||||
elif a in tasks:
|
||||
@ -224,7 +259,7 @@ def entrypoint(debug=''):
|
||||
mode = overrides.get('mode', None)
|
||||
if mode is None:
|
||||
mode = DEFAULT_CFG.mode or 'predict'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'mode=' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
|
||||
elif mode not in modes:
|
||||
if mode != 'checks':
|
||||
raise ValueError(emojis(f"ERROR ❌ Invalid 'mode={mode}'. Valid modes are {modes}."))
|
||||
@ -237,7 +272,7 @@ def entrypoint(debug=''):
|
||||
task = overrides.pop('task', None)
|
||||
if model is None:
|
||||
model = task2model.get(task, 'yolov8n.pt')
|
||||
LOGGER.warning(f"WARNING ⚠️ 'model=' is missing. Using default 'model={model}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
overrides['model'] = model
|
||||
model = YOLO(model)
|
||||
@ -251,15 +286,15 @@ def entrypoint(debug=''):
|
||||
if mode == 'predict' and 'source' not in overrides:
|
||||
overrides['source'] = DEFAULT_CFG.source or ROOT / "assets" if (ROOT / "assets").exists() \
|
||||
else "https://ultralytics.com/images/bus.jpg"
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source=' is missing. Using default 'source={overrides['source']}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
||||
elif mode in ('train', 'val'):
|
||||
if 'data' not in overrides:
|
||||
overrides['data'] = task2data.get(task, DEFAULT_CFG.data)
|
||||
LOGGER.warning(f"WARNING ⚠️ 'data=' is missing. Using {model.task} default 'data={overrides['data']}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using {model.task} default 'data={overrides['data']}'.")
|
||||
elif mode == 'export':
|
||||
if 'format' not in overrides:
|
||||
overrides['format'] = DEFAULT_CFG.format or 'torchscript'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'format=' is missing. Using default 'format={overrides['format']}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
|
||||
|
||||
# Run command in python
|
||||
# getattr(model, mode)(**vars(get_cfg(overrides=overrides))) # default args using default.yaml
|
||||
|
@ -132,7 +132,7 @@ class YOLODataset(BaseDataset):
|
||||
for lb in labels:
|
||||
lb["segments"] = []
|
||||
if len_cls == 0:
|
||||
raise ValueError(f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}")
|
||||
raise ValueError(f"All labels empty in {cache_path}, can not start training without labels. {HELP_URL}")
|
||||
return labels
|
||||
|
||||
# TODO: use hyp config to set all these augmentations
|
||||
|
@ -131,7 +131,11 @@ def yaml_save(file='data.yaml', data=None):
|
||||
|
||||
with open(file, 'w') as f:
|
||||
# Dump data to file in YAML format, converting Path objects to strings
|
||||
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
||||
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v
|
||||
for k, v in data.items()},
|
||||
f,
|
||||
sort_keys=False,
|
||||
allow_unicode=True)
|
||||
|
||||
|
||||
def yaml_load(file='data.yaml', append_filename=False):
|
||||
@ -164,7 +168,7 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
|
||||
None
|
||||
"""
|
||||
yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
|
||||
dump = yaml.dump(yaml_dict, default_flow_style=False)
|
||||
dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True)
|
||||
LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
|
||||
|
||||
|
||||
|
@ -7,6 +7,7 @@ import sys
|
||||
import tempfile
|
||||
|
||||
from . import USER_CONFIG_DIR
|
||||
from .torch_utils import TORCH_1_9
|
||||
|
||||
|
||||
def find_free_network_port() -> int:
|
||||
@ -47,8 +48,9 @@ def generate_ddp_command(world_size, trainer):
|
||||
using_cli = not file_name.endswith(".py")
|
||||
if using_cli:
|
||||
file_name = generate_ddp_file(trainer)
|
||||
torch_distributed_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
|
||||
return [
|
||||
sys.executable, "-m", "torch.distributed.run", "--nproc_per_node", f"{world_size}", "--master_port",
|
||||
sys.executable, "-m", torch_distributed_cmd, "--nproc_per_node", f"{world_size}", "--master_port",
|
||||
f"{find_free_network_port()}", file_name] + sys.argv[1:]
|
||||
|
||||
|
||||
|
@ -24,6 +24,10 @@ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
|
||||
|
||||
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
|
||||
TORCH_1_11 = check_version(torch.__version__, '1.11.0')
|
||||
TORCH_1_12 = check_version(torch.__version__, '1.12.0')
|
||||
|
||||
|
||||
@contextmanager
|
||||
def torch_distributed_zero_first(local_rank: int):
|
||||
@ -36,10 +40,10 @@ def torch_distributed_zero_first(local_rank: int):
|
||||
dist.barrier(device_ids=[0])
|
||||
|
||||
|
||||
def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
|
||||
def smart_inference_mode():
|
||||
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
|
||||
def decorate(fn):
|
||||
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
|
||||
return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
|
||||
|
||||
return decorate
|
||||
|
||||
@ -49,7 +53,7 @@ def DDP_model(model):
|
||||
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
|
||||
'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
|
||||
'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
|
||||
if check_version(torch.__version__, '1.11.0'):
|
||||
if TORCH_1_11:
|
||||
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
|
||||
else:
|
||||
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
||||
@ -267,7 +271,7 @@ def init_seeds(seed=0, deterministic=False):
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
|
||||
# torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
|
||||
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
|
||||
if deterministic and TORCH_1_12: # https://github.com/ultralytics/yolov5/pull/8213
|
||||
torch.use_deterministic_algorithms(True)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
||||
|
Reference in New Issue
Block a user