From 99275814f125b3b26ca155b22e9b2d51ac8ea5e4 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 5 Jan 2023 13:49:16 +0100 Subject: [PATCH] Add AutoBatch from YOLOv5 (#145) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/yolo/engine/trainer.py | 12 ++++- ultralytics/yolo/engine/validator.py | 2 +- ultralytics/yolo/utils/autobatch.py | 72 +++++++++++++++++++++++++++ ultralytics/yolo/utils/torch_utils.py | 51 +++++++++++++++++++ ultralytics/yolo/v8/detect/train.py | 11 +++- 5 files changed, 144 insertions(+), 4 deletions(-) create mode 100644 ultralytics/yolo/utils/autobatch.py diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 4cbb4f0..30a15e0 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -27,6 +27,7 @@ from ultralytics.yolo.configs import get_config from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, yaml_save) +from ultralytics.yolo.utils.autobatch import check_train_batch_size from ultralytics.yolo.utils.checks import check_file, print_args from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command from ultralytics.yolo.utils.files import get_latest_run, increment_path @@ -135,7 +136,7 @@ class BaseTrainer: self.fitness = None self.loss = None self.tloss = None - self.loss_names = None + self.loss_names = ['Loss'] self.csv = self.save_dir / 'results.csv' self.plot_idx = [0, 1, 2] @@ -192,6 +193,15 @@ class BaseTrainer: self.set_model_attributes() if world_size > 1: self.model = DDP(self.model, device_ids=[rank]) + + # Batch size + if self.batch_size == -1: + if RANK == -1: # single-GPU only, estimate best batch size + self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp) + else: + SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. ' + 'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16') + # Optimizer self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index 055599d..8ca8d04 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -78,7 +78,7 @@ class BaseValidator: self.device = trainer.device self.data = trainer.data model = trainer.ema.ema or trainer.model - self.args.half &= self.device.type != 'cpu' + self.args.half = self.device.type != 'cpu' # force FP16 val during training model = model.half() if self.args.half else model.float() self.model = model self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device) diff --git a/ultralytics/yolo/utils/autobatch.py b/ultralytics/yolo/utils/autobatch.py new file mode 100644 index 0000000..7181cdb --- /dev/null +++ b/ultralytics/yolo/utils/autobatch.py @@ -0,0 +1,72 @@ +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Auto-batch utils +""" + +from copy import deepcopy + +import numpy as np +import torch + +from ultralytics.yolo.utils import LOGGER, colorstr +from ultralytics.yolo.utils.torch_utils import profile + + +def check_train_batch_size(model, imgsz=640, amp=True): + # Check YOLOv5 training batch size + with torch.cuda.amp.autocast(amp): + return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size + + +def autobatch(model, imgsz=640, fraction=0.7, batch_size=16): + # Automatically estimate best YOLOv5 batch size to use `fraction` of available CUDA memory + # Usage: + # import torch + # from utils.autobatch import autobatch + # model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False) + # print(autobatch(model)) + + # Check device + prefix = colorstr('AutoBatch: ') + LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}') + device = next(model.parameters()).device # get model device + if device.type == 'cpu': + LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}') + return batch_size + if torch.backends.cudnn.benchmark: + LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}') + return batch_size + + # Inspect CUDA memory + gb = 1 << 30 # bytes to GiB (1024 ** 3) + d = str(device).upper() # 'CUDA:0' + properties = torch.cuda.get_device_properties(device) # device properties + t = properties.total_memory / gb # GiB total + r = torch.cuda.memory_reserved(device) / gb # GiB reserved + a = torch.cuda.memory_allocated(device) / gb # GiB allocated + f = t - (r + a) # GiB free + LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free') + + # Profile batch sizes + batch_sizes = [1, 2, 4, 8, 16] + try: + img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes] + results = profile(img, model, n=3, device=device) + except Exception as e: + LOGGER.warning(f'{prefix}{e}') + + # Fit a solution + y = [x[2] for x in results if x] # memory [2] + p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit + b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size) + if None in results: # some sizes failed + i = results.index(None) # first fail index + if b >= batch_sizes[i]: # y intercept above failure point + b = batch_sizes[max(i - 1, 0)] # select prior safe point + if b < 1 or b > 1024: # b outside of safe range + b = batch_size + LOGGER.warning(f'{prefix}WARNING ⚠️ CUDA anomaly detected, recommend restart environment and retry command.') + + fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted + LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅') + return b diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py index 5bdb1bd..89fe589 100644 --- a/ultralytics/yolo/utils/torch_utils.py +++ b/ultralytics/yolo/utils/torch_utils.py @@ -299,3 +299,54 @@ def guess_task_from_head(head): raise SyntaxError("task or model not recognized! Please refer the docs at : ") # TODO: add docs links return task + + +def profile(input, ops, n=10, device=None): + """ YOLOv5 speed/memory/FLOPs profiler + Usage: + input = torch.randn(16, 3, 640, 640) + m1 = lambda x: x * torch.sigmoid(x) + m2 = nn.SiLU() + profile(input, [m1, m2], n=100) # profile over 100 iterations + """ + results = [] + if not isinstance(device, torch.device): + device = select_device(device) + print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}" + f"{'input':>24s}{'output':>24s}") + + for x in input if isinstance(input, list) else [input]: + x = x.to(device) + x.requires_grad = True + for m in ops if isinstance(ops, list) else [ops]: + m = m.to(device) if hasattr(m, 'to') else m # device + m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m + tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward + try: + flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs + except Exception: + flops = 0 + + try: + for _ in range(n): + t[0] = time_sync() + y = m(x) + t[1] = time_sync() + try: + _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward() + t[2] = time_sync() + except Exception: # no backward method + # print(e) # for debug + t[2] = float('nan') + tf += (t[1] - t[0]) * 1000 / n # ms per op forward + tb += (t[2] - t[1]) * 1000 / n # ms per op backward + mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB) + s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes + p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters + print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}') + results.append([p, flops, mem, tf, tb, s_in, s_out]) + except Exception as e: + print(e) + results.append(None) + torch.cuda.empty_cache() + return results diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py index f4aa037..f8a5e17 100644 --- a/ultralytics/yolo/v8/detect/train.py +++ b/ultralytics/yolo/v8/detect/train.py @@ -74,9 +74,16 @@ class DetectionTrainer(BaseTrainer): return self.compute_loss(preds, batch) def label_loss_items(self, loss_items=None, prefix="train"): - # We should just use named tensors here in future + """ + Returns a loss dict with labelled training loss items tensor + """ + # Not needed for classification but necessary for segmentation & detection keys = [f"{prefix}/{x}" for x in self.loss_names] - return dict(zip(keys, loss_items)) if loss_items is not None else keys + if loss_items is not None: + loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats + return dict(zip(keys, loss_items)) + else: + return keys def progress_string(self): return ('\n' + '%11s' *