You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
519 lines
22 KiB
519 lines
22 KiB
2 years ago
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||
2 years ago
|
|
||
2 years ago
|
import math
|
||
2 years ago
|
import os
|
||
2 years ago
|
import platform
|
||
2 years ago
|
import random
|
||
2 years ago
|
import time
|
||
2 years ago
|
from contextlib import contextmanager
|
||
2 years ago
|
from copy import deepcopy
|
||
|
from pathlib import Path
|
||
2 years ago
|
from typing import Union
|
||
2 years ago
|
|
||
2 years ago
|
import numpy as np
|
||
2 years ago
|
import torch
|
||
|
import torch.distributed as dist
|
||
2 years ago
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
2 years ago
|
import torchvision
|
||
2 years ago
|
|
||
1 year ago
|
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, __version__
|
||
|
from ultralytics.utils.checks import check_requirements, check_version
|
||
2 years ago
|
|
||
2 years ago
|
try:
|
||
|
import thop
|
||
|
except ImportError:
|
||
|
thop = None
|
||
|
|
||
2 years ago
|
TORCHVISION_0_10 = check_version(torchvision.__version__, '0.10.0')
|
||
2 years ago
|
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
|
||
|
TORCH_1_11 = check_version(torch.__version__, '1.11.0')
|
||
|
TORCH_1_12 = check_version(torch.__version__, '1.12.0')
|
||
2 years ago
|
TORCH_2_0 = check_version(torch.__version__, minimum='2.0')
|
||
2 years ago
|
|
||
2 years ago
|
|
||
|
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Context manager that lets the local master (rank 0) run a block before the other ranks in distributed training.

    Non-master ranks wait at a barrier on entry; the master reaches the matching barrier on exit, releasing them.
    When no process group is initialized, or for local_rank -1/0 on entry, no waiting happens before the block.

    Args:
        local_rank (int): Local rank of the current process (-1 when not running distributed).
    """
    ddp_active = torch.distributed.is_available() and torch.distributed.is_initialized()
    is_master = local_rank == 0
    if ddp_active and not is_master and local_rank != -1:
        dist.barrier(device_ids=[local_rank])
    yield
    if ddp_active and is_master:
        dist.barrier(device_ids=[0])
|
||
|
|
||
|
|
||
2 years ago
|
def smart_inference_mode():
    """Return a decorator that runs a function under torch.inference_mode() on torch>=1.9.0, else torch.no_grad()."""

    def decorate(fn):
        """Wrap `fn` in the inference context manager supported by the installed torch version."""
        context_factory = torch.inference_mode if TORCH_1_9 else torch.no_grad
        return context_factory()(fn)

    return decorate
|
||
|
|
||
|
|
||
1 year ago
|
def get_cpu_info():
    """Return a cleaned-up CPU brand string for this machine, i.e. 'Apple M2'."""
    check_requirements('py-cpuinfo')
    import cpuinfo  # noqa

    brand = cpuinfo.get_cpu_info()['brand_raw']
    for token in ('(R)', 'CPU ', '@ '):  # strip trademark marks and filler words, in this order
        brand = brand.replace(token, '')
    return brand
|
||
|
|
||
|
|
||
2 years ago
|
def select_device(device='', batch=0, newline=False, verbose=True):
    """
    Select the PyTorch device to run on and log a short environment summary.

    Args:
        device (str | int | torch.device, optional): Requested device. Accepts None/'' (first CUDA device if available,
            else CPU), 'cpu', 'mps', 'cuda', a CUDA index such as 0 or '0', or a comma-separated list of indices such
            as '0,1,2,3'. Forms like 'cuda:0' and '(0, 1)' are normalized. Defaults to ''.
        batch (int, optional): Batch size; with several GPUs selected it must be a multiple of the GPU count.
            Defaults to 0 (no check).
        newline (bool, optional): Keep the trailing newline in the logged summary. Defaults to False.
        verbose (bool, optional): Log the summary (only when RANK == -1). Defaults to True.

    Returns:
        (torch.device): torch.device('cuda:0'), torch.device('mps') or torch.device('cpu').

    Raises:
        ValueError: If the requested CUDA device(s) are not available, or if `batch` is not a multiple of the number
            of selected GPUs.
    """
    s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
    device = str(device).lower()
    for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
        device = device.replace(remove, '')  # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
    cpu = device == 'cpu'
    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        if device == 'cuda':
            device = '0'
        if ',' in device:
            device = ','.join(x for x in device.split(',') if x)  # drop empty entries, i.e. '0,1,' -> '0,1'
        visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
        # Count requested devices by comma-separated index, not by character ('10' is one device, not two)
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(','))):
            LOGGER.info(s)
            install = ('See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no '
                       'CUDA devices are seen by torch.\n' if torch.cuda.device_count() == 0 else '')
            raise ValueError(f"Invalid CUDA 'device={device}' requested."
                             f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
                             f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
                             f'\ntorch.cuda.is_available(): {torch.cuda.is_available()}'
                             f'\ntorch.cuda.device_count(): {torch.cuda.device_count()}'
                             f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
                             f'{install}')

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(',') if device else '0'  # range(torch.cuda.device_count()) # i.e. 0,1,6,7
        n = len(devices)  # device count
        if n > 1 and batch > 0 and batch % n != 0:  # check batch_size is divisible by device_count
            raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
                             f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.")
        space = ' ' * (len(s) + 1)
        for i, d in enumerate(devices):
            # Index i (not d): CUDA_VISIBLE_DEVICES re-numbers the visible GPUs from 0
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
        arg = 'cuda:0'
    elif mps and TORCH_2_0 and torch.backends.mps.is_available():
        # Prefer MPS if available; query torch.backends.mps directly instead of the deprecated torch.has_mps flag.
        # TORCH_2_0 is checked first so torch.backends.mps is never touched on older torch.
        s += f'MPS ({get_cpu_info()})\n'
        arg = 'mps'
    else:  # revert to CPU
        s += f'CPU ({get_cpu_info()})\n'
        arg = 'cpu'

    if verbose and RANK == -1:
        LOGGER.info(s if newline else s.rstrip())
    return torch.device(arg)
|
||
2 years ago
|
|
||
|
|
||
|
def time_sync():
    """Return the current wall-clock time in seconds, after waiting for queued CUDA work when a GPU is available."""
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.synchronize()  # make pending GPU kernels count towards the measured interval
    return time.time()
|
||
|
|
||
|
|
||
|
def fuse_conv_and_bn(conv, bn):
    """
    Fold a BatchNorm2d layer into the preceding Conv2d, returning a single equivalent Conv2d.

    See https://tehnokv.com/posts/fusing-batchnorm-and-conv/.

    Args:
        conv (nn.Conv2d): Convolution whose output feeds `bn`.
        bn (nn.BatchNorm2d): BatchNorm layer; its running statistics (not batch statistics) are folded in.

    Returns:
        (nn.Conv2d): New convolution with a bias and requires_grad disabled, on the same device as `conv.weight`.
    """
    device = conv.weight.device
    fused = nn.Conv2d(conv.in_channels,
                      conv.out_channels,
                      kernel_size=conv.kernel_size,
                      stride=conv.stride,
                      padding=conv.padding,
                      dilation=conv.dilation,
                      groups=conv.groups,
                      bias=True).requires_grad_(False).to(device)

    # Filters: scale each output channel of the conv by gamma / sqrt(running_var + eps)
    bn_scale = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    conv_filters = conv.weight.clone().view(conv.out_channels, -1)
    fused.weight.copy_(torch.mm(bn_scale, conv_filters).view(fused.weight.shape))

    # Bias: scaled conv bias (zeros when the conv has none) plus the BN shift beta - gamma * mean / sqrt(var + eps)
    conv_bias = conv.bias if conv.bias is not None else torch.zeros(conv.weight.size(0), device=device)
    bn_shift = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fused.bias.copy_(torch.mm(bn_scale, conv_bias.reshape(-1, 1)).reshape(-1) + bn_shift)
    return fused
|
||
|
|
||
|
|
||
2 years ago
|
def fuse_deconv_and_bn(deconv, bn):
    """
    Fold a BatchNorm2d layer into the preceding ConvTranspose2d, returning a single equivalent ConvTranspose2d.

    Unlike Conv2d, ConvTranspose2d stores its weight as (in_channels, out_channels // groups, kH, kW). The BN scale
    therefore applies along dim 1 within each group, not along dim 0.

    Args:
        deconv (nn.ConvTranspose2d): Transposed convolution whose output feeds `bn`.
        bn (nn.BatchNorm2d): BatchNorm layer over `deconv.out_channels` channels; its running statistics are folded in.

    Returns:
        (nn.ConvTranspose2d): New layer with a bias and requires_grad disabled, on the same device as `deconv.weight`.
    """
    fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
                                    deconv.out_channels,
                                    kernel_size=deconv.kernel_size,
                                    stride=deconv.stride,
                                    padding=deconv.padding,
                                    output_padding=deconv.output_padding,
                                    dilation=deconv.dilation,
                                    groups=deconv.groups,
                                    bias=True).requires_grad_(False).to(deconv.weight.device)

    groups = deconv.groups
    in_per_group = deconv.in_channels // groups
    out_per_group = deconv.out_channels // groups

    # Per-output-channel BN scale gamma / sqrt(running_var + eps)
    scale = bn.weight.div(torch.sqrt(bn.eps + bn.running_var))

    # Prepare filters: split the weight into (groups, in/groups, out/groups, kH, kW) so each output channel
    # g * out_per_group + j is scaled by scale[g * out_per_group + j]
    w_deconv = deconv.weight.reshape(groups, in_per_group, out_per_group, *deconv.weight.shape[2:])
    w_fused = w_deconv * scale.view(groups, 1, out_per_group, 1, 1)
    fuseddconv.weight.copy_(w_fused.reshape(fuseddconv.weight.shape))

    # Prepare spatial bias (one entry per output channel, zeros when the deconv has no bias)
    b_conv = torch.zeros(deconv.out_channels, device=deconv.weight.device) if deconv.bias is None else deconv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fuseddconv.bias.copy_(scale * b_conv + b_bn)
    return fuseddconv
|
||
|
|
||
|
|
||
2 years ago
|
def model_info(model, detailed=False, verbose=True, imgsz=640):
    """
    Log a model summary with layer, parameter and gradient counts plus GFLOPs at `imgsz`.

    Args:
        model (nn.Module): Model to summarize.
        detailed (bool, optional): Also log one row per named parameter. Defaults to False.
        verbose (bool, optional): When False, do nothing and return None. Defaults to True.
        imgsz (int | list, optional): Image size for the GFLOPs estimate, i.e. imgsz=640 or imgsz=[640, 320].

    Returns:
        (tuple | None): (n_layers, n_params, n_gradients, gflops) when `verbose`, otherwise None.
    """
    if not verbose:
        return
    n_p = get_num_params(model)  # number of parameters
    n_g = get_num_gradients(model)  # number of gradients
    n_l = len(list(model.modules()))  # number of layers

    if detailed:
        LOGGER.info(
            f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
        row_fmt = '%5g %40s %9s %12g %20s %10.3g %10.3g %10s'
        for idx, (param_name, param) in enumerate(model.named_parameters()):
            short_name = param_name.replace('module_list.', '')
            stats = (idx, short_name, param.requires_grad, param.numel(), list(param.shape), param.mean(), param.std(),
                     param.dtype)
            LOGGER.info(row_fmt % stats)

    flops = get_flops(model, imgsz)
    is_fused = getattr(model, 'is_fused', lambda: False)()
    fused = ' (fused)' if is_fused else ''
    fs = f', {flops:.1f} GFLOPs' if flops else ''
    yaml_file = getattr(model, 'yaml_file', '') or getattr(model, 'yaml', {}).get('yaml_file', '')
    model_name = Path(yaml_file).stem.replace('yolo', 'YOLO') or 'Model'
    LOGGER.info(f'{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}')
    return n_l, n_p, n_g, flops
|
||
2 years ago
|
|
||
|
|
||
|
def get_num_params(model):
    """Return the total number of parameter elements in a YOLO model, trainable or frozen."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
||
|
|
||
|
|
||
|
def get_num_gradients(model):
    """Return the number of parameter elements in a YOLO model that require gradients (i.e. are trainable)."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)
|
||
|
|
||
|
|
||
2 years ago
|
def model_info_for_loggers(trainer):
    """
    Return a dict of model metrics for experiment loggers.

    If `trainer.args.profile` is set, ONNX/TensorRT/PyTorch speeds are measured with ProfileModels on `trainer.last`.
    Otherwise the dict holds the parameter count, GFLOPs and the PyTorch inference speed from the most recent
    validation.

    Example for YOLOv8n:
        {'model/parameters': 3151904,
         'model/GFLOPs': 8.746,
         'model/speed_ONNX(ms)': 41.244,
         'model/speed_TensorRT(ms)': 3.211,
         'model/speed_PyTorch(ms)': 18.755}
    """
    if not trainer.args.profile:  # only return PyTorch times from most recent validation
        return {
            'model/parameters': get_num_params(trainer.model),
            'model/GFLOPs': round(get_flops(trainer.model), 3),
            'model/speed_PyTorch(ms)': round(trainer.validator.speed['inference'], 3)}

    # profile ONNX and TensorRT times
    from ultralytics.utils.benchmarks import ProfileModels

    profiled = ProfileModels([trainer.last], device=trainer.device).profile()[0]
    profiled.pop('model/name')
    return profiled
|
||
|
|
||
|
|
||
2 years ago
|
def get_flops(model, imgsz=640):
    """
    Return a YOLO model's GFLOPs at the given image size, or 0.0 if they cannot be computed.

    FLOPs are first profiled with thop on a small stride-sized input and scaled to `imgsz`. If that fails (e.g. a
    model that only runs at its real input size), the model is profiled directly at `imgsz`. Profiling always runs
    on a deep copy, so `model` itself is not modified.

    Args:
        model (nn.Module): Model to profile; DP/DDP wrappers are removed first.
        imgsz (int | list | tuple, optional): Image size, i.e. 640 or [640, 320]. Defaults to 640.

    Returns:
        (float): GFLOPs, or 0.0 when thop is not installed or profiling fails.
    """
    if not thop:
        return 0.0  # thop not installed
    try:
        model = de_parallel(model)
        p = next(model.parameters())
        # Accept tuples as well as lists; previously a tuple was wrapped as [tuple, tuple] and silently yielded 0
        imgsz = list(imgsz) if isinstance(imgsz, (list, tuple)) else [imgsz, imgsz]  # expand if int/float
        try:
            stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
            im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
            flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2  # stride GFLOPs
            return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
        except Exception:
            # Stride-sized input failed; profile at the actual image size instead
            im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
            return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2  # imgsz GFLOPs
    except Exception:
        return 0.0  # best-effort: FLOPs are informational only
|
||
2 years ago
|
|
||
|
|
||
2 years ago
|
def get_flops_with_torch_profiler(model, imgsz=640):
    """
    Compute model GFLOPs with torch.profiler (thop alternative).

    The model is run once on a zero input of twice its max stride (at least 32) per side. The profiled FLOPs are
    then scaled to `imgsz`.

    Args:
        model (nn.Module): Model to profile; DP/DDP wrappers are removed first.
        imgsz (int | list | tuple, optional): Image size, i.e. 640 or [640, 320]. Defaults to 640.

    Returns:
        (float): Estimated GFLOPs at `imgsz`.
    """
    model = de_parallel(model)
    p = next(model.parameters())
    stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2  # max stride
    im = torch.zeros((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
    # NOTE: the forward pass runs on `model` itself (not a copy)
    with torch.profiler.profile(with_flops=True) as prof:
        model(im)
    flops = sum(x.flops for x in prof.key_averages()) / 1E9
    imgsz = list(imgsz) if isinstance(imgsz, (list, tuple)) else [imgsz, imgsz]  # expand if int/float
    return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
|
||
|
|
||
|
|
||
2 years ago
|
def initialize_weights(model):
    """
    Apply default YOLO layer settings to a model in place.

    Conv2d weights are left at PyTorch's default initialization. BatchNorm2d layers get eps=1e-3 and momentum=0.03.
    Hardswish/LeakyReLU/ReLU/ReLU6/SiLU activations are switched to in-place mode. Types are matched exactly, so
    subclasses of these layers are not modified.
    """
    inplace_activations = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
    for module in model.modules():
        kind = type(module)
        if kind is nn.BatchNorm2d:
            module.eps = 1e-3
            module.momentum = 0.03
        elif kind in inplace_activations:
            module.inplace = True
        # nn.Conv2d: intentionally untouched (was: nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu'))
|
||
|
|
||
|
|
||
|
def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    """
    Resize an image batch by `ratio` (bilinear), then pad or crop it on the bottom/right edges.

    Args:
        img (torch.Tensor): Image batch whose last two dims are (h, w), i.e. (bs, 3, h, w).
        ratio (float): Scale factor; exactly 1.0 returns `img` itself unchanged.
        same_shape (bool): If True, pad/crop back to the original (h, w). Otherwise pad each scaled side up to the
            next multiple of `gs`.
        gs (int): Grid size (stride) used when `same_shape` is False.

    Returns:
        (torch.Tensor): The resized batch, with padding filled with 0.447 (ImageNet mean).
    """
    if ratio == 1.0:
        return img
    h, w = img.shape[2:]
    new_h, new_w = int(h * ratio), int(w * ratio)  # new size
    img = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False)  # resize
    if same_shape:
        target_h, target_w = h, w
    else:
        target_h, target_w = (math.ceil(side * ratio / gs) * gs for side in (h, w))
    # Negative padding amounts crop instead of pad
    return F.pad(img, [0, target_w - new_w, 0, target_h - new_h], value=0.447)  # value = imagenet mean
|
||
|
|
||
|
|
||
2 years ago
|
def make_divisible(x, divisor):
    """
    Round `x` up to the nearest multiple of `divisor`.

    Args:
        x (int | float): Value to round up.
        divisor (int | torch.Tensor): Divisor; for a tensor, its maximum element (as int) is used.

    Returns:
        (int): The smallest multiple of the divisor that is >= `x`.
    """
    step = int(divisor.max()) if isinstance(divisor, torch.Tensor) else divisor  # to int
    return math.ceil(x / step) * step
|
||
|
|
||
|
|
||
2 years ago
|
def copy_attr(a, b, include=(), exclude=()):
    """
    Copy attributes from object 'b' onto object 'a'.

    Names starting with '_' and names listed in `exclude` are skipped. When `include` is non-empty, only the names
    it lists are copied.
    """
    for key, value in b.__dict__.items():
        if key.startswith('_') or key in exclude:
            continue
        if len(include) and key not in include:
            continue
        setattr(a, key, value)
|
||
|
|
||
|
|
||
2 years ago
|
def get_latest_opset():
    """Return the second-newest ONNX opset exported by this torch build (the newest is skipped for maturity)."""
    tag = 'symbolic_opset'  # 14 characters; the opset number follows it
    newest = max(int(name[len(tag):]) for name in vars(torch.onnx) if tag in name)
    return newest - 1  # opset
|
||
2 years ago
|
|
||
|
|
||
2 years ago
|
def intersect_dicts(da, db, exclude=()):
    """
    Return the entries of `da` whose keys also appear in `db` with an identically shaped value.

    Keys containing any substring listed in `exclude` are dropped. The returned values are taken from `da`.
    """
    shared = {}
    for key, value in da.items():
        if key not in db or any(pattern in key for pattern in exclude):
            continue
        if value.shape == db[key].shape:
            shared[key] = value
    return shared
|
||
2 years ago
|
|
||
|
|
||
|
def is_parallel(model):
    """Return True if `model` is wrapped in nn.DataParallel (DP) or nn.parallel.DistributedDataParallel (DDP)."""
    parallel_wrappers = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    return isinstance(model, parallel_wrappers)
|
||
2 years ago
|
|
||
|
|
||
|
def de_parallel(model):
    """Unwrap a DP/DDP model: return `model.module` when `model` is parallel-wrapped, otherwise `model` itself."""
    if is_parallel(model):
        return model.module
    return model
|
||
2 years ago
|
|
||
|
|
||
2 years ago
|
def one_cycle(y1=0.0, y2=1.0, steps=100):
    """
    Return a function of x that ramps sinusoidally (half cosine) from y1 at x=0 to y2 at x=steps.

    See https://arxiv.org/pdf/1812.01187.pdf.
    """

    def ramp(x):
        progress = (1 - math.cos(x * math.pi / steps)) / 2  # 0 -> 1 over [0, steps]
        return progress * (y2 - y1) + y1

    return ramp
|
||
|
|
||
|
|
||
2 years ago
|
def init_seeds(seed=0, deterministic=False):
    """
    Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators.

    See https://pytorch.org/docs/stable/notes/randomness.html.

    Args:
        seed (int, optional): Seed value. Defaults to 0.
        deterministic (bool, optional): Request deterministic algorithms (torch>=2.0 only; otherwise a warning is
            logged). When False, deterministic algorithms and cuDNN determinism are switched off. Defaults to False.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)  # manual_seed_all covers Multi-GPU and is exception safe
    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
    if not deterministic:
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = False
        return
    if not TORCH_2_0:
        LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.')
        return
    torch.use_deterministic_algorithms(True, warn_only=True)  # warn if deterministic is not possible
    torch.backends.cudnn.deterministic = True
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    os.environ['PYTHONHASHSEED'] = str(seed)
|
||
2 years ago
|
|
||
|
|
||
2 years ago
|
class ModelEMA:
    """
    Exponential Moving Average (EMA) of a model, updated from https://github.com/rwightman/pytorch-image-models.

    Every floating-point entry of the model state_dict (parameters and buffers) is averaged. Non-float entries keep
    the values copied at construction. The effective decay ramps up from 0 towards `decay` as updates accumulate,
    which helps early epochs.
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    To disable EMA set the `enabled` attribute to `False`.
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        """
        Create the EMA copy of `model`.

        Args:
            model (nn.Module): Source model; DP/DDP wrappers are removed, then it is deep-copied and put in eval mode.
            decay (float): Asymptotic decay rate.
            tau (int | float): Ramp time constant; the decay after x updates is decay * (1 - exp(-x / tau)).
            updates (int): Number of EMA updates already performed (e.g. when resuming training).
        """
        self.ema = deepcopy(de_parallel(model)).eval()  # EMA copy, kept in eval mode
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for param in self.ema.parameters():
            param.requires_grad_(False)
        self.enabled = True

    def update(self, model):
        """Blend the current weights of `model` into the EMA: ema = d * ema + (1 - d) * model, float entries only."""
        if not self.enabled:
            return
        self.updates += 1
        d = self.decay(self.updates)

        model_state = de_parallel(model).state_dict()  # model state_dict
        for key, ema_value in self.ema.state_dict().items():
            if ema_value.dtype.is_floating_point:  # true for FP16 and FP32
                # state_dict() tensors share storage with the EMA module, so these in-place ops update it directly
                ema_value *= d
                ema_value += (1 - d) * model_state[key].detach()
                # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy public attributes from `model` onto the EMA model, skipping DDP internals by default."""
        if self.enabled:
            copy_attr(self.ema, model, include, exclude)
|
||
2 years ago
|
|
||
|
|
||
2 years ago
|
def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
    """
    Strip optimizer from 'f' to finalize training, optionally save as 's'.

    The steps are:
        - If present, the checkpoint's EMA weights replace 'model'.
        - The keys 'optimizer', 'best_fitness', 'ema' and 'updates' are set to None, and 'epoch' is set to -1.
        - The model is converted to FP16 with gradients disabled.
        - If present, 'train_args' is reduced to default-config keys.

    Args:
        f (str | Path): file path to model to strip the optimizer from. Default is 'best.pt'.
        s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.

    Returns:
        None

    Usage:
        from pathlib import Path
        from ultralytics.utils.torch_utils import strip_optimizer
        for f in Path('path/to/weights').rglob('*.pt'):
            strip_optimizer(f)
    """
    # Use dill (if exists) to serialize the lambda functions where pickle does not do this
    try:
        import dill as pickle
    except ImportError:
        import pickle

    # NOTE: torch.load unpickles arbitrary objects - only run this on trusted checkpoint files
    x = torch.load(f, map_location=torch.device('cpu'))
    if 'model' not in x:
        LOGGER.info(f'Skipping {f}, not a valid Ultralytics model.')
        return

    if hasattr(x['model'], 'args'):
        x['model'].args = dict(x['model'].args)  # convert from IterableSimpleNamespace to dict
    args = {**DEFAULT_CFG_DICT, **x['train_args']} if 'train_args' in x else None  # combine args
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    if args:  # checkpoints without 'train_args' leave args as None; calling .items() on it would raise
        x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
    # x['model'].args = x['train_args']
    torch.save(x, s or f, pickle_module=pickle)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||
2 years ago
|
|
||
|
|
||
2 years ago
|
def profile(input, ops, n=10, device=None):
    """
    YOLOv8 speed/memory/FLOPs profiler

    Runs every op in `ops` on every tensor in `input` and logs one table row per (input, op) pair.

    Args:
        input (torch.Tensor | list): An input tensor or a list of input tensors.
        ops (callable | nn.Module | list): An op or a list of ops to profile.
        n (int, optional): Number of timed forward/backward iterations per pair. Defaults to 10.
        device (str | torch.device, optional): Device to run on; anything other than a torch.device is resolved
            with select_device(). Defaults to None.

    Returns:
        (list): One entry per (input, op) pair: [params, GFLOPs, GPU_mem (GB), forward (ms), backward (ms),
            input shape, output shape], or None if that pair raised.

    Usage:
        input = torch.randn(16, 3, 640, 640)
        m1 = lambda x: x * torch.sigmoid(x)
        m2 = nn.SiLU()
        profile(input, [m1, m2], n=100)  # profile over 100 iterations
    """
    results = []
    if not isinstance(device, torch.device):
        device = select_device(device)
    LOGGER.info(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
                f"{'input':>24s}{'output':>24s}")

    input_list = input if isinstance(input, list) else [input]
    for x in input_list:
        x = x.to(device)
        x.requires_grad = True
        op_list = ops if isinstance(ops, list) else [ops]
        for m in op_list:
            if hasattr(m, 'to'):
                m = m.to(device)  # device
            if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16:
                m = m.half()
            t_fwd, t_bwd, stamps = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1E9 * 2 if thop else 0  # GFLOPs
            except Exception:
                flops = 0

            try:
                for _ in range(n):
                    stamps[0] = time_sync()
                    y = m(x)
                    stamps[1] = time_sync()
                    try:
                        _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                        stamps[2] = time_sync()
                    except Exception:  # no backward method
                        stamps[2] = float('nan')
                    t_fwd += (stamps[1] - stamps[0]) * 1000 / n  # ms per op forward
                    t_bwd += (stamps[2] - stamps[1]) * 1000 / n  # ms per op backward
                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
                s_in, s_out = (tuple(v.shape) if isinstance(v, torch.Tensor) else 'list' for v in (x, y))  # shapes
                p = sum(v.numel() for v in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                LOGGER.info(f'{p:12}{flops:12.4g}{mem:>14.3f}{t_fwd:14.4g}{t_bwd:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
                results.append([p, flops, mem, t_fwd, t_bwd, s_in, s_out])
            except Exception as e:
                LOGGER.info(e)
                results.append(None)
            torch.cuda.empty_cache()
    return results
|
||
2 years ago
|
|
||
|
|
||
|
class EarlyStopping:
    """Stop training once a given number of epochs has passed without a fitness improvement."""

    def __init__(self, patience=50):
        """
        Initialize the early-stopping tracker.

        Args:
            patience (int, optional): Epochs to wait after the last improvement before stopping. Any falsy value
                (e.g. 0) disables early stopping. Defaults to 50.
        """
        self.best_fitness = 0.0  # i.e. mAP
        self.best_epoch = 0
        self.patience = patience or float('inf')  # epochs to wait after fitness stops improving to stop
        self.possible_stop = False  # possible stop may occur next epoch

    def __call__(self, epoch, fitness):
        """
        Record this epoch's fitness and decide whether training should stop.

        Args:
            epoch (int): Current epoch of training.
            fitness (float | None): Fitness value of the current epoch; None (i.e. when val=False) never stops.

        Returns:
            (bool): True if training should stop, False otherwise.
        """
        if fitness is None:  # check if fitness=None (happens when val=False)
            return False

        if fitness >= self.best_fitness:  # >= 0 to allow for early zero-fitness stage of training
            self.best_epoch, self.best_fitness = epoch, fitness
        stale_epochs = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = stale_epochs >= (self.patience - 1)  # possible stop may occur next epoch
        stop = stale_epochs >= self.patience  # stop training if patience exceeded
        if stop:
            LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                        f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
                        f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
                        f'i.e. `patience=300` or use `patience=0` to disable EarlyStopping.')
        return stop
|