Move `check_amp()` to checks.py (#2948)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Branch: single_channel
Authored by Ayush Chaurasia, committed via GitHub
parent d07ba25dc4
commit 67cf53b475

@@ -101,7 +101,7 @@ def test_val_scratch():
 def test_amp():
     if torch.cuda.is_available():
-        from ultralytics.yolo.engine.trainer import check_amp
+        from ultralytics.yolo.utils.checks import check_amp
         model = YOLO(MODEL).model.cuda()
         assert check_amp(model)

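For reference, the relocated helper can be exercised directly from its new home. A minimal usage sketch mirroring the test above (it assumes a CUDA device and that the 'yolov8n.pt' weights are available locally or downloadable):

import torch

from ultralytics import YOLO
from ultralytics.yolo.utils.checks import check_amp

if torch.cuda.is_available():
    model = YOLO('yolov8n.pt').model.cuda()  # underlying nn.Module moved to the GPU
    print(check_amp(model))  # True if AMP inference matches FP32 within tolerance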
@@ -24,10 +24,10 @@ from tqdm import tqdm
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
 from ultralytics.yolo.cfg import get_cfg
 from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
-from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, ONLINE, RANK, ROOT, SETTINGS, TQDM_BAR_FORMAT, __version__,
-                                    callbacks, clean_url, colorstr, emojis, yaml_save)
+from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks,
+                                    clean_url, colorstr, emojis, yaml_save)
 from ultralytics.yolo.utils.autobatch import check_train_batch_size
-from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
+from ultralytics.yolo.utils.checks import check_amp, check_file, check_imgsz, print_args
 from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.yolo.utils.files import get_latest_run, increment_path
 from ultralytics.yolo.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle,
@@ -648,52 +648,3 @@ class BaseTrainer:
         LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
                     f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
         return optimizer
-
-
-def check_amp(model):
-    """
-    This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model.
-    If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
-    results, so AMP will be disabled during training.
-
-    Args:
-        model (nn.Module): A YOLOv8 model instance.
-
-    Returns:
-        (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
-
-    Raises:
-        AssertionError: If the AMP checks fail, indicating anomalies with the AMP functionality on the system.
-    """
-    device = next(model.parameters()).device  # get model device
-    if device.type in ('cpu', 'mps'):
-        return False  # AMP only used on CUDA devices
-
-    def amp_allclose(m, im):
-        """All close FP32 vs AMP results."""
-        a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
-        with torch.cuda.amp.autocast(True):
-            b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
-        del m
-        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance
-
-    f = ROOT / 'assets/bus.jpg'  # image to check
-    im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
-    prefix = colorstr('AMP: ')
-    LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
-    warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
-    try:
-        from ultralytics import YOLO
-        assert amp_allclose(YOLO('yolov8n.pt'), im)
-        LOGGER.info(f'{prefix}checks passed ✅')
-    except ConnectionError:
-        LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
-    except (AttributeError, ModuleNotFoundError):
-        LOGGER.warning(
-            f'{prefix}checks skipped ⚠️. Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}')
-    except AssertionError:
-        LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
-                       f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
-        return False
-    return True

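With check_amp() now exported from checks.py, trainer code imports it like any other check. A hypothetical sketch of how training code can gate mixed precision on the result (TrainerSketch and setup_amp are illustrative names, not the verbatim BaseTrainer code):

import torch

from ultralytics.yolo.utils.checks import check_amp


class TrainerSketch:

    def setup_amp(self, model, requested_amp=True):
        # Enable AMP only when the user asked for it AND the system passes the check
        self.amp = bool(requested_amp and check_amp(model))
        # GradScaler degrades to a no-op when enabled=False, so the training loop is unchanged
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)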
@@ -344,6 +344,55 @@ def check_yolo(verbose=True, device=''):
     LOGGER.info(f'Setup complete ✅ {s}')
+
+
+def check_amp(model):
+    """
+    This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model.
+    If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
+    results, so AMP will be disabled during training.
+
+    Args:
+        model (nn.Module): A YOLOv8 model instance.
+
+    Returns:
+        (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
+
+    Raises:
+        AssertionError: If the AMP checks fail, indicating anomalies with the AMP functionality on the system.
+    """
+    device = next(model.parameters()).device  # get model device
+    if device.type in ('cpu', 'mps'):
+        return False  # AMP only used on CUDA devices
+
+    def amp_allclose(m, im):
+        """All close FP32 vs AMP results."""
+        a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
+        with torch.cuda.amp.autocast(True):
+            b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
+        del m
+        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance
+
+    f = ROOT / 'assets/bus.jpg'  # image to check
+    im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
+    prefix = colorstr('AMP: ')
+    LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
+    warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
+    try:
+        from ultralytics import YOLO
+        assert amp_allclose(YOLO('yolov8n.pt'), im)
+        LOGGER.info(f'{prefix}checks passed ✅')
+    except ConnectionError:
+        LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
+    except (AttributeError, ModuleNotFoundError):
+        LOGGER.warning(
+            f'{prefix}checks skipped ⚠️. Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}')
+    except AssertionError:
+        LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
+                       f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
+        return False
+    return True
+
+
 def git_describe(path=ROOT):  # path must be a directory
     # Return human-readable git description, i.e. v5.0-5-g3e25f1e  https://git-scm.com/docs/git-describe
     try:

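The core of the check is amp_allclose(): run the same input through the model once in FP32 and once under torch.cuda.amp.autocast, then compare the outputs within an absolute tolerance. A self-contained sketch of that pattern on a toy module (the Conv2d model, input shape, and reuse of the 0.5 tolerance are illustrative assumptions, not part of the commit):

import torch
import torch.nn as nn


def amp_matches_fp32(module, x, atol=0.5):
    a = module(x)  # FP32 inference
    with torch.cuda.amp.autocast(True):
        b = module(x)  # AMP inference (reduced precision where safe)
    # Same output shape and values within atol => AMP is behaving on this system
    return a.shape == b.shape and torch.allclose(a, b.float(), atol=atol)


if torch.cuda.is_available():
    m = nn.Conv2d(3, 8, 3).cuda().eval()
    x = torch.randn(1, 3, 64, 64, device='cuda')
    print(amp_matches_fp32(m, x))  # True on a healthy AMP setup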
@@ -386,7 +386,7 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
     import pickle
     x = torch.load(f, map_location=torch.device('cpu'))
-    args = {**DEFAULT_CFG_DICT, **x['train_args']}  # combine model args with default args, preferring model args
+    args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None  # combine args
     if x.get('ema'):
         x['model'] = x['ema']  # replace model with ema
     for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys

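The strip_optimizer() change above makes the args merge tolerant of checkpoints saved without 'train_args'. A toy illustration of the pattern (the dict contents are made-up stand-ins, not the real DEFAULT_CFG_DICT):

DEFAULT_CFG_DICT = {'imgsz': 640, 'epochs': 100}  # stand-in for the real defaults

x = {'train_args': {'epochs': 50}}  # checkpoint that saved its training args
args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None
print(args)  # {'imgsz': 640, 'epochs': 50} -- checkpoint values override defaults

x = {}  # older checkpoint with no saved training args
args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None
print(args)  # None -- previously this raised KeyError: 'train_args'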