Add AutoBatch from YOLOv5 (#145)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -27,6 +27,7 @@ from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
||||
yaml_save)
|
||||
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
||||
from ultralytics.yolo.utils.checks import check_file, print_args
|
||||
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
||||
from ultralytics.yolo.utils.files import get_latest_run, increment_path
|
||||
@ -135,7 +136,7 @@ class BaseTrainer:
|
||||
self.fitness = None
|
||||
self.loss = None
|
||||
self.tloss = None
|
||||
self.loss_names = None
|
||||
self.loss_names = ['Loss']
|
||||
self.csv = self.save_dir / 'results.csv'
|
||||
self.plot_idx = [0, 1, 2]
|
||||
|
||||
@ -192,6 +193,15 @@ class BaseTrainer:
|
||||
self.set_model_attributes()
|
||||
if world_size > 1:
|
||||
self.model = DDP(self.model, device_ids=[rank])
|
||||
|
||||
# Batch size
|
||||
if self.batch_size == -1:
|
||||
if RANK == -1: # single-GPU only, estimate best batch size
|
||||
self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
|
||||
else:
|
||||
SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
|
||||
'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')
|
||||
|
||||
# Optimizer
|
||||
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
|
||||
self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
||||
|
@ -78,7 +78,7 @@ class BaseValidator:
|
||||
self.device = trainer.device
|
||||
self.data = trainer.data
|
||||
model = trainer.ema.ema or trainer.model
|
||||
self.args.half &= self.device.type != 'cpu'
|
||||
self.args.half = self.device.type != 'cpu' # force FP16 val during training
|
||||
model = model.half() if self.args.half else model.float()
|
||||
self.model = model
|
||||
self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
||||
|
72
ultralytics/yolo/utils/autobatch.py
Normal file
72
ultralytics/yolo/utils/autobatch.py
Normal file
@ -0,0 +1,72 @@
|
||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||
"""
|
||||
Auto-batch utils
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, colorstr
|
||||
from ultralytics.yolo.utils.torch_utils import profile
|
||||
|
||||
|
||||
def check_train_batch_size(model, imgsz=640, amp=True):
|
||||
# Check YOLOv5 training batch size
|
||||
with torch.cuda.amp.autocast(amp):
|
||||
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
|
||||
|
||||
|
||||
def autobatch(model, imgsz=640, fraction=0.7, batch_size=16):
|
||||
# Automatically estimate best YOLOv5 batch size to use `fraction` of available CUDA memory
|
||||
# Usage:
|
||||
# import torch
|
||||
# from utils.autobatch import autobatch
|
||||
# model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
|
||||
# print(autobatch(model))
|
||||
|
||||
# Check device
|
||||
prefix = colorstr('AutoBatch: ')
|
||||
LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
|
||||
device = next(model.parameters()).device # get model device
|
||||
if device.type == 'cpu':
|
||||
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
|
||||
return batch_size
|
||||
if torch.backends.cudnn.benchmark:
|
||||
LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
|
||||
return batch_size
|
||||
|
||||
# Inspect CUDA memory
|
||||
gb = 1 << 30 # bytes to GiB (1024 ** 3)
|
||||
d = str(device).upper() # 'CUDA:0'
|
||||
properties = torch.cuda.get_device_properties(device) # device properties
|
||||
t = properties.total_memory / gb # GiB total
|
||||
r = torch.cuda.memory_reserved(device) / gb # GiB reserved
|
||||
a = torch.cuda.memory_allocated(device) / gb # GiB allocated
|
||||
f = t - (r + a) # GiB free
|
||||
LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
|
||||
|
||||
# Profile batch sizes
|
||||
batch_sizes = [1, 2, 4, 8, 16]
|
||||
try:
|
||||
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
|
||||
results = profile(img, model, n=3, device=device)
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'{prefix}{e}')
|
||||
|
||||
# Fit a solution
|
||||
y = [x[2] for x in results if x] # memory [2]
|
||||
p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
|
||||
b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
|
||||
if None in results: # some sizes failed
|
||||
i = results.index(None) # first fail index
|
||||
if b >= batch_sizes[i]: # y intercept above failure point
|
||||
b = batch_sizes[max(i - 1, 0)] # select prior safe point
|
||||
if b < 1 or b > 1024: # b outside of safe range
|
||||
b = batch_size
|
||||
LOGGER.warning(f'{prefix}WARNING ⚠️ CUDA anomaly detected, recommend restart environment and retry command.')
|
||||
|
||||
fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
|
||||
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
|
||||
return b
|
@ -299,3 +299,54 @@ def guess_task_from_head(head):
|
||||
raise SyntaxError("task or model not recognized! Please refer the docs at : ") # TODO: add docs links
|
||||
|
||||
return task
|
||||
|
||||
|
||||
def profile(input, ops, n=10, device=None):
|
||||
""" YOLOv5 speed/memory/FLOPs profiler
|
||||
Usage:
|
||||
input = torch.randn(16, 3, 640, 640)
|
||||
m1 = lambda x: x * torch.sigmoid(x)
|
||||
m2 = nn.SiLU()
|
||||
profile(input, [m1, m2], n=100) # profile over 100 iterations
|
||||
"""
|
||||
results = []
|
||||
if not isinstance(device, torch.device):
|
||||
device = select_device(device)
|
||||
print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
||||
f"{'input':>24s}{'output':>24s}")
|
||||
|
||||
for x in input if isinstance(input, list) else [input]:
|
||||
x = x.to(device)
|
||||
x.requires_grad = True
|
||||
for m in ops if isinstance(ops, list) else [ops]:
|
||||
m = m.to(device) if hasattr(m, 'to') else m # device
|
||||
m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
|
||||
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
|
||||
try:
|
||||
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
|
||||
except Exception:
|
||||
flops = 0
|
||||
|
||||
try:
|
||||
for _ in range(n):
|
||||
t[0] = time_sync()
|
||||
y = m(x)
|
||||
t[1] = time_sync()
|
||||
try:
|
||||
_ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
|
||||
t[2] = time_sync()
|
||||
except Exception: # no backward method
|
||||
# print(e) # for debug
|
||||
t[2] = float('nan')
|
||||
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
||||
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
||||
mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
|
||||
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
|
||||
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
|
||||
print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
|
||||
results.append([p, flops, mem, tf, tb, s_in, s_out])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
results.append(None)
|
||||
torch.cuda.empty_cache()
|
||||
return results
|
||||
|
@ -74,9 +74,16 @@ class DetectionTrainer(BaseTrainer):
|
||||
return self.compute_loss(preds, batch)
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||
# We should just use named tensors here in future
|
||||
"""
|
||||
Returns a loss dict with labelled training loss items tensor
|
||||
"""
|
||||
# Not needed for classification but necessary for segmentation & detection
|
||||
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
||||
return dict(zip(keys, loss_items)) if loss_items is not None else keys
|
||||
if loss_items is not None:
|
||||
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
|
||||
return dict(zip(keys, loss_items))
|
||||
else:
|
||||
return keys
|
||||
|
||||
def progress_string(self):
|
||||
return ('\n' + '%11s' *
|
||||
|
Reference in New Issue
Block a user