# Ultralytics YOLO 🚀, GPL-3.0 license

import math
import os
import platform
import random
import time
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path

import numpy as np
import thop
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

import ultralytics
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER
from ultralytics.yolo.utils.checks import git_describe

from .checks import check_version

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager that makes all processes in distributed training wait for the local master to run the block.

    Ranks other than -1 and 0 block on a barrier before entering; rank 0 runs the block first and then hits the
    matching barrier, releasing the others. Nothing is synchronized if no process group is initialized.

    Args:
        local_rank (int): local rank of this process (LOCAL_RANK defaults to -1 when the env var is unset).
    """
    initialized = torch.distributed.is_initialized()  # prevent 'Default process group has not been initialized' errors
    if initialized and local_rank not in {-1, 0}:
        dist.barrier(device_ids=[local_rank])
    yield
    if initialized and local_rank == 0:
        dist.barrier(device_ids=[0])


def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
    """
    Return a decorator applying torch.inference_mode() if torch>=1.9.0, else torch.no_grad().

    Args:
        torch_1_9 (bool): whether torch>=1.9.0; evaluated once at import time by default.
    """

    def decorate(fn):
        return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)

    return decorate


def DDP_model(model):
    """
    Wrap `model` in DistributedDataParallel on LOCAL_RANK, refusing torch==1.12.0.

    `static_graph=True` is passed only when torch>=1.11.0.
    """
    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
    if check_version(torch.__version__, '1.11.0'):
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
    else:
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)


def select_device(device='', batch_size=0, newline=False):
    """
    Select the torch device to run on and log an environment summary.

    Args:
        device (str | int | None): 'cpu', 'mps', a CUDA index such as 0 or '0', a comma-separated list such as
            '0,1,2,3', or ''/None for the default (first CUDA device if available, else CPU).
            A 'cuda:' prefix is stripped.
        batch_size (int): if > 0 and several CUDA devices are requested, it must be divisible by the device count.
        newline (bool): keep the trailing newline in the logged summary.

    Returns:
        (torch.device): 'cuda:0', 'mps' or 'cpu'.

    Side effects:
        Sets os.environ['CUDA_VISIBLE_DEVICES'] ('-1' for cpu/mps, the requested IDs otherwise). Visible GPUs are
        then renumbered from 0, which is why 'cuda:0' is returned and properties are read by enumeration index.
        The summary is only logged when RANK == -1.
    """
    ver = git_describe() or ultralytics.__version__  # git commit or pip package version
    s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
    device = str(device).strip().lower().replace('cuda:', '').replace('none', '')  # to string, 'cuda:0' to '0'
    cpu = device == 'cpu'
    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
        # Count requested device IDs, not characters: '10' is one device, not two.
        assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(',')), \
            f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(',') if device else '0'  # range(torch.cuda.device_count())  # i.e. 0,1,6,7
        n = len(devices)  # device count
        if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * (len(s) + 1)
        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)  # i, not d: visible devices are renumbered from 0
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MiB
        arg = 'cuda:0'
    elif mps and getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
        # prefer MPS if available; torch.backends.mps only exists on torch>=1.12 (torch.has_mps is deprecated)
        s += 'MPS\n'
        arg = 'mps'
    else:  # revert to CPU
        s += 'CPU\n'
        arg = 'cpu'

    if RANK == -1:
        LOGGER.info(s if newline else s.rstrip())
    return torch.device(arg)


def time_sync():
    """Return time.time(), first waiting for pending CUDA work to finish so GPU timings are accurate."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def fuse_conv_and_bn(conv, bn):
    """
    Fuse a Conv2d() and the BatchNorm2d() that follows it into a single biased Conv2d.

    See https://tehnokv.com/posts/fusing-batchnorm-and-conv/

    Returns:
        (nn.Conv2d): the fused layer, with gradients disabled, on conv's weight device.
    """
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # Prepare filters: W = diag(gamma / sqrt(var + eps)) @ W_conv
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # Prepare spatial bias: b = diag(...) @ b_conv + (beta - gamma * mean / sqrt(var + eps))
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv


def model_info(model, verbose=False, imgsz=640):
    """
    Log a one-line model summary (layers, parameters, gradients, GFLOPs), optionally printing a per-parameter table.

    Args:
        model (nn.Module): model to summarize.
        verbose (bool): also print name, shape, mean and std of every parameter.
        imgsz (int | list): image size for the GFLOPs estimate, i.e. imgsz=640 or imgsz=[640, 320].
    """
    n_p = get_num_params(model)
    n_g = get_num_gradients(model)  # number gradients
    if verbose:
        print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    flops = get_flops(model, imgsz)
    fs = f', {flops:.1f} GFLOPs' if flops else ''
    # The model may define neither `yaml_file` nor `yaml`; fall back to 'Model' instead of raising AttributeError.
    yaml_dict = getattr(model, 'yaml', None) or {}
    yaml_file = getattr(model, 'yaml_file', '') or yaml_dict.get('yaml_file', '')
    m = Path(yaml_file).stem.replace('yolo', 'YOLO') or 'Model'
    LOGGER.info(f"{m} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")


def get_num_params(model):
    """Return the total number of elements across all model parameters."""
    return sum(x.numel() for x in model.parameters())


def get_num_gradients(model):
    """Return the number of parameter elements that require gradients."""
    return sum(x.numel() for x in model.parameters() if x.requires_grad)


def get_flops(model, imgsz=640):
    """
    Estimate model GFLOPs at `imgsz` with thop; returns 0 if profiling fails for any reason.

    The model is profiled once on a (stride x stride) input and the result is scaled linearly to imgsz.
    """
    try:
        model = de_parallel(model)
        p = next(model.parameters())
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
        # p.shape[1] is taken as the input channel count (assumes the first parameter is the first conv weight)
        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
        flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs (MACs x2)
        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
        flops = flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
        return flops
    except Exception:  # best-effort: FLOPs are informational only
        return 0


def initialize_weights(model):
    """Set BatchNorm2d eps=1e-3 / momentum=0.03 and make common activations in-place; Conv2d is left untouched."""
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
            m.inplace = True


def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    """
    Scale an image batch img(bs,3,y,x) by `ratio` with bilinear interpolation.

    Unless `same_shape`, the result is padded (or cropped, via negative padding) to the original size scaled by
    ratio and rounded up to a multiple of `gs`.
    """
    if ratio == 1.0:
        return img
    h, w = img.shape[2:]
    s = (int(h * ratio), int(w * ratio))  # new size
    img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
    if not same_shape:  # pad/crop img
        h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
    return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean


def make_divisible(x, divisor):
    """Return x rounded up to the next multiple of divisor (a tensor divisor uses its max value)."""
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor


def copy_attr(a, b, include=(), exclude=()):
    """Copy public attributes from b to a, optionally only those in `include` and never those in `exclude`."""
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        else:
            setattr(a, k, v)


def intersect_dicts(da, db, exclude=()):
    """Return da's items whose key is in db with a matching .shape, skipping keys containing any `exclude` substring."""
    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}


def is_parallel(model):
    """Return True if model is exactly of type DataParallel or DistributedDataParallel."""
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def de_parallel(model):
    """De-parallelize a model: return model.module if it is DP or DDP, else the model itself."""
    return model.module if is_parallel(model) else model


def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Return a lambda giving a sinusoidal (half-cosine) ramp from y1 at x=0 to y2 at x=steps. https://arxiv.org/pdf/1812.01187.pdf"""
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1


def init_seeds(seed=0, deterministic=False):
    """
    Seed the random, numpy and torch (CPU and all CUDA devices) RNGs.

    See https://pytorch.org/docs/stable/notes/randomness.html

    Args:
        seed (int): seed value.
        deterministic (bool): on torch>=1.12.0, also enable deterministic algorithms and cuDNN determinism.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
    if deterministic and check_version(torch.__version__, '1.12.0'):  # https://github.com/ultralytics/yolov5/pull/8213
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only affects child processes.
        os.environ['PYTHONHASHSEED'] = str(seed)


class ModelEMA:
    """
    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        """
        Create the EMA as a frozen, eval-mode deep copy of the (de-parallelized) model.

        Args:
            model (nn.Module): model to track.
            decay (float): asymptotic decay rate.
            tau (float): ramp time constant; effective decay is decay * (1 - exp(-updates / tau)).
            updates (int): number of updates already applied (e.g. when resuming).
        """
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Blend floating-point state_dict entries in place: ema = d * ema + (1 - d) * model."""
        self.updates += 1
        d = self.decay(self.updates)

        msd = de_parallel(model).state_dict()  # model state_dict
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # true for FP16 and FP32
                v *= d
                v += (1 - d) * msd[k].detach()
        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy public (non-underscore) attributes from model onto the EMA model, respecting include/exclude."""
        copy_attr(self.ema, model, include, exclude)


def strip_optimizer(f='best.pt', s=''):
    """
    Strip optimizer from 'f' to finalize training, optionally save as 's'.

    Usage:
        from ultralytics.yolo.utils.torch_utils import strip_optimizer
        from pathlib import Path
        for f in Path('/Users/glennjocher/Downloads/weights').glob('*.pt'):
            strip_optimizer(f)

    The EMA model (if present) replaces the model, which is cast to FP16 with gradients disabled; optimizer,
    best_fitness, ema and updates are set to None, epoch to -1, and train_args is reduced to default config keys.

    Args:
        f (str): file path to model state to strip the optimizer from. Default is 'best.pt'.
        s (str): file path to save the model with stripped optimizer to. Default is ''. If not provided,
                 the original file will be overwritten.

    Returns:
        None
    """
    x = torch.load(f, map_location=torch.device('cpu'))  # NOTE: torch.load unpickles; only use on trusted files
    args = {**DEFAULT_CONFIG_DICT, **x['train_args']}  # combine model args with default args, preferring model args
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS}  # strip non-default keys
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")


def guess_task_from_head(head):
    """
    Map a model head name to its task.

    Args:
        head (str): head name, matched case-insensitively.

    Returns:
        (str): 'classify', 'detect' or 'segment'.

    Raises:
        SyntaxError: if the head name is not recognized.
    """
    head_to_task = {
        'classify': 'classify',
        'classifier': 'classify',
        'cls': 'classify',
        'fc': 'classify',
        'detect': 'detect',
        'segment': 'segment'}
    task = head_to_task.get(head.lower())
    if not task:
        raise SyntaxError("task or model not recognized! Please refer the docs at : ")  # TODO: add docs links
    return task


def profile(input, ops, n=10, device=None):
    """
    YOLOv5 speed/memory/FLOPs profiler
    Usage:
        input = torch.randn(16, 3, 640, 640)
        m1 = lambda x: x * torch.sigmoid(x)
        m2 = nn.SiLU()
        profile(input, [m1, m2], n=100)  # profile over 100 iterations

    Prints one row per (input, op) pair and returns a list whose entries are
    [params, GFLOPs, GPU mem (GB), forward ms, backward ms, input shape, output shape], or None if that op failed.
    """
    results = []
    if not isinstance(device, torch.device):
        device = select_device(device)
    print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
          f"{'input':>24s}{'output':>24s}")

    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, 'to') else m  # device
            m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
            except Exception:
                flops = 0

            try:
                for _ in range(n):
                    t[0] = time_sync()
                    y = m(x)
                    t[1] = time_sync()
                    try:
                        _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                        t[2] = time_sync()
                    except Exception:  # no backward method
                        # print(e)  # for debug
                        t[2] = float('nan')
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
                s_in, s_out = (tuple(a.shape) if isinstance(a, torch.Tensor) else 'list' for a in (x, y))  # shapes
                p = sum(q.numel() for q in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                print(e)
                results.append(None)
            torch.cuda.empty_cache()
    return results