Move check_amp() from ultralytics/yolo/engine/trainer.py to ultralytics/yolo/utils/checks.py.

Summary of the hunks below:
  * tests/test_python.py: test_amp() imports check_amp from ultralytics.yolo.utils.checks.
  * ultralytics/yolo/engine/trainer.py: the check_amp() definition is deleted and the name is
    imported from ultralytics.yolo.utils.checks instead; ONLINE and ROOT are dropped from the
    ultralytics.yolo.utils import list.
  * ultralytics/yolo/utils/checks.py: check_amp() is added with the same body that was removed
    from trainer.py.
  * ultralytics/yolo/utils/torch_utils.py: strip_optimizer() only merges x['train_args'] into the
    defaults when that key is present, otherwise args is None.

Review notes:
  * The check_amp() docstring lists "Raises: AssertionError", but the function body catches
    AssertionError, logs a warning and returns False.
  * NOTE(review): strip_optimizer() can now leave args as None; confirm that the rest of the
    function (not part of this patch) tolerates that.
  * NOTE(review): the line layout of this patch was restored from a whitespace-collapsed copy;
    indentation was rebuilt from Python syntax and the hunk line counts. Verify against the
    repository before applying.

diff --git a/tests/test_python.py b/tests/test_python.py
index d51f091..9616d6f 100644
--- a/tests/test_python.py
+++ b/tests/test_python.py
@@ -101,7 +101,7 @@ def test_val_scratch():
 
 def test_amp():
     if torch.cuda.is_available():
-        from ultralytics.yolo.engine.trainer import check_amp
+        from ultralytics.yolo.utils.checks import check_amp
         model = YOLO(MODEL).model.cuda()
         assert check_amp(model)
 
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index 7966526..b39ad13 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -24,10 +24,10 @@ from tqdm import tqdm
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
 from ultralytics.yolo.cfg import get_cfg
 from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
-from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, ONLINE, RANK, ROOT, SETTINGS, TQDM_BAR_FORMAT, __version__,
-                                    callbacks, clean_url, colorstr, emojis, yaml_save)
+from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks,
+                                    clean_url, colorstr, emojis, yaml_save)
 from ultralytics.yolo.utils.autobatch import check_train_batch_size
-from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
+from ultralytics.yolo.utils.checks import check_amp, check_file, check_imgsz, print_args
 from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.yolo.utils.files import get_latest_run, increment_path
 from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle,
@@ -648,52 +648,3 @@ class BaseTrainer:
         LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
                     f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
         return optimizer
-
-
-def check_amp(model):
-    """
-    This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model.
-    If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
-    results, so AMP will be disabled during training.
-
-    Args:
-        model (nn.Module): A YOLOv8 model instance.
-
-    Returns:
-        (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
-
-    Raises:
-        AssertionError: If the AMP checks fail, indicating anomalies with the AMP functionality on the system.
-    """
-    device = next(model.parameters()).device  # get model device
-    if device.type in ('cpu', 'mps'):
-        return False  # AMP only used on CUDA devices
-
-    def amp_allclose(m, im):
-        """All close FP32 vs AMP results."""
-        a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
-        with torch.cuda.amp.autocast(True):
-            b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
-        del m
-        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance
-
-    f = ROOT / 'assets/bus.jpg'  # image to check
-    im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
-    prefix = colorstr('AMP: ')
-    LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
-    warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
-    try:
-        from ultralytics import YOLO
-        assert amp_allclose(YOLO('yolov8n.pt'), im)
-        LOGGER.info(f'{prefix}checks passed ✅')
-    except ConnectionError:
-        LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
-    except (AttributeError, ModuleNotFoundError):
-        LOGGER.warning(
-            f'{prefix}checks skipped ⚠️. Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}'
-        )
-    except AssertionError:
-        LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
-                       f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
-        return False
-    return True
diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py
index 2767560..03b5048 100644
--- a/ultralytics/yolo/utils/checks.py
+++ b/ultralytics/yolo/utils/checks.py
@@ -344,6 +344,55 @@ def check_yolo(verbose=True, device=''):
     LOGGER.info(f'Setup complete ✅ {s}')
 
 
+def check_amp(model):
+    """
+    This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model.
+    If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
+    results, so AMP will be disabled during training.
+
+    Args:
+        model (nn.Module): A YOLOv8 model instance.
+
+    Returns:
+        (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
+
+    Raises:
+        AssertionError: If the AMP checks fail, indicating anomalies with the AMP functionality on the system.
+    """
+    device = next(model.parameters()).device  # get model device
+    if device.type in ('cpu', 'mps'):
+        return False  # AMP only used on CUDA devices
+
+    def amp_allclose(m, im):
+        """All close FP32 vs AMP results."""
+        a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
+        with torch.cuda.amp.autocast(True):
+            b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
+        del m
+        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance
+
+    f = ROOT / 'assets/bus.jpg'  # image to check
+    im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
+    prefix = colorstr('AMP: ')
+    LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
+    warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
+    try:
+        from ultralytics import YOLO
+        assert amp_allclose(YOLO('yolov8n.pt'), im)
+        LOGGER.info(f'{prefix}checks passed ✅')
+    except ConnectionError:
+        LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
+    except (AttributeError, ModuleNotFoundError):
+        LOGGER.warning(
+            f'{prefix}checks skipped ⚠️. Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}'
+        )
+    except AssertionError:
+        LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
+                       f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
+        return False
+    return True
+
+
 def git_describe(path=ROOT):  # path must be a directory
     # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
     try:
diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py
index 98c0302..f576ccd 100644
--- a/ultralytics/yolo/utils/torch_utils.py
+++ b/ultralytics/yolo/utils/torch_utils.py
@@ -386,7 +386,7 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
         import pickle
 
     x = torch.load(f, map_location=torch.device('cpu'))
-    args = {**DEFAULT_CFG_DICT, **x['train_args']}  # combine model args with default args, preferring model args
+    args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None  # combine args
     if x.get('ema'):
         x['model'] = x['ema']  # replace model with ema
     for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys