From 3633d4c06b806a971f444919c389ac6e889cdb23 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 7 Feb 2023 02:37:13 +0400 Subject: [PATCH] TAL `min_memory` argument, precommit, Docker fixes (#836) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jaap van de Loosdrecht --- .pre-commit-config.yaml | 4 +- docker/Dockerfile-arm64 | 2 +- ultralytics/yolo/cfg/default.yaml | 1 + ultralytics/yolo/engine/exporter.py | 25 ++++++++---- ultralytics/yolo/utils/tal.py | 62 +++++++++++++++++++++-------- ultralytics/yolo/v8/detect/train.py | 13 ++++-- 6 files changed, 76 insertions(+), 31 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 88c6c98..6af0efb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,7 +37,7 @@ repos: # - id: isort # name: Sort imports - - repo: https://github.com/pre-commit/mirrors-yapf + - repo: https://github.com/google/yapf rev: v0.32.0 hooks: - id: yapf @@ -54,7 +54,7 @@ repos: # exclude: "README.md|README.zh-CN.md|CONTRIBUTING.md" - repo: https://github.com/PyCQA/flake8 - rev: 5.0.4 + rev: 6.0.0 hooks: - id: flake8 name: PEP8 diff --git a/docker/Dockerfile-arm64 b/docker/Dockerfile-arm64 index ad4d9e1..44e594b 100644 --- a/docker/Dockerfile-arm64 +++ b/docker/Dockerfile-arm64 @@ -27,7 +27,7 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics COPY requirements.txt . RUN python3 -m pip install --upgrade pip wheel RUN pip install --no-cache ultralytics albumentations gsutil notebook \ - coremltools onnx onnx-simplifier onnxruntime openvino-dev>=2022.3 + coremltools onnx onnxruntime # tensorflow-aarch64 tensorflowjs \ # Cleanup diff --git a/ultralytics/yolo/cfg/default.yaml b/ultralytics/yolo/cfg/default.yaml index 5be02eb..4f0a43e 100644 --- a/ultralytics/yolo/cfg/default.yaml +++ b/ultralytics/yolo/cfg/default.yaml @@ -29,6 +29,7 @@ rect: False # support rectangular training if mode='train', support rectangular cos_lr: False # use cosine learning rate scheduler close_mosaic: 10 # disable mosaic augmentation for final 10 epochs resume: False # resume training from last checkpoint +min_memory: False # minimize memory footprint loss function, choices=[False, True, ] # Segmentation overlap_mask: True # masks should overlap during training (segment train only) mask_ratio: 4 # mask downsample ratio (segment train only) diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py index 681f2f1..a70a68a 100644 --- a/ultralytics/yolo/engine/exporter.py +++ b/ultralytics/yolo/engine/exporter.py @@ -82,13 +82,19 @@ MACOS = platform.system() == 'Darwin' # macOS environment def export_formats(): # YOLOv8 export formats - x = [['PyTorch', '-', '.pt', True, True], ['TorchScript', 'torchscript', '.torchscript', True, True], - ['ONNX', 'onnx', '.onnx', True, True], ['OpenVINO', 'openvino', '_openvino_model', True, False], - ['TensorRT', 'engine', '.engine', False, True], ['CoreML', 'coreml', '.mlmodel', True, False], - ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True], - ['TensorFlow GraphDef', 'pb', '.pb', True, True], ['TensorFlow Lite', 'tflite', '.tflite', True, False], - ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False], - ['TensorFlow.js', 'tfjs', '_web_model', False, False], ['PaddlePaddle', 'paddle', '_paddle_model', True, True]] + x = [ + ['PyTorch', '-', '.pt', True, True], + ['TorchScript', 'torchscript', '.torchscript', True, True], + ['ONNX', 'onnx', '.onnx', True, True], + ['OpenVINO', 'openvino', '_openvino_model', True, False], + ['TensorRT', 'engine', '.engine', False, True], + ['CoreML', 'coreml', '.mlmodel', True, False], + ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True], + ['TensorFlow GraphDef', 'pb', '.pb', True, True], + ['TensorFlow Lite', 'tflite', '.tflite', True, False], + ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False], + ['TensorFlow.js', 'tfjs', '_web_model', False, False], + ['PaddlePaddle', 'paddle', '_paddle_model', True, True],] return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU']) @@ -138,9 +144,12 @@ class Exporter: self.run_callbacks("on_export_start") t = time.time() format = self.args.format.lower() # to lowercase + if format in {'tensorrt', 'trt'}: # engine aliases + format = 'engine' fmts = tuple(export_formats()['Argument'][1:]) # available export formats flags = [x == format for x in fmts] - assert sum(flags), f'ERROR: Invalid format={format}, valid formats are {fmts}' + if sum(flags) != 1: + raise ValueError(f"Invalid export format='{format}'. Valid formats are {fmts}") jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags # export booleans # Load PyTorch model diff --git a/ultralytics/yolo/utils/tal.py b/ultralytics/yolo/utils/tal.py index 98481ad..45b4c16 100644 --- a/ultralytics/yolo/utils/tal.py +++ b/ultralytics/yolo/utils/tal.py @@ -10,7 +10,7 @@ from .metrics import bbox_iou TORCH_1_10 = check_version(torch.__version__, '1.10.0') -def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): +def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9, roll_out=False): """select the positive anchor center in gt Args: @@ -21,10 +21,18 @@ def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): """ n_anchors = xy_centers.shape[0] bs, n_boxes, _ = gt_bboxes.shape - lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom - bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1) - # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype) - return bbox_deltas.amin(3).gt_(eps) + if roll_out: + bbox_deltas = torch.empty((bs, n_boxes, n_anchors), device=gt_bboxes.device) + for b in range(bs): + lt, rb = gt_bboxes[b].view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom + bbox_deltas[b] = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), + dim=2).view(n_boxes, n_anchors, -1).amin(2).gt_(eps) + return bbox_deltas + else: + lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom + bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1) + # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype) + return bbox_deltas.amin(3).gt_(eps) def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): @@ -55,7 +63,7 @@ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): class TaskAlignedAssigner(nn.Module): - def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9): + def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9, roll_out_thr=0): super().__init__() self.topk = topk self.num_classes = num_classes @@ -63,6 +71,7 @@ class TaskAlignedAssigner(nn.Module): self.alpha = alpha self.beta = beta self.eps = eps + self.roll_out_thr = roll_out_thr @torch.no_grad() def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): @@ -84,6 +93,7 @@ class TaskAlignedAssigner(nn.Module): """ self.bs = pd_scores.size(0) self.n_max_boxes = gt_bboxes.size(1) + self.roll_out = self.n_max_boxes > self.roll_out_thr if self.roll_out_thr else False if self.n_max_boxes == 0: device = gt_bboxes.device @@ -112,7 +122,7 @@ class TaskAlignedAssigner(nn.Module): # get anchor_align metric, (b, max_num_obj, h*w) align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes) # get in_gts mask, (b, max_num_obj, h*w) - mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes) + mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes, roll_out=self.roll_out) # get topk_metric mask, (b, max_num_obj, h*w) mask_topk = self.select_topk_candidates(align_metric * mask_in_gts, topk_mask=mask_gt.repeat([1, 1, self.topk]).bool()) @@ -122,14 +132,27 @@ class TaskAlignedAssigner(nn.Module): return mask_pos, align_metric, overlaps def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes): - ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj - ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes) # b, max_num_obj - ind[1] = gt_labels.long().squeeze(-1) # b, max_num_obj - # get the scores of each grid for each gt cls - bbox_scores = pd_scores[ind[0], :, ind[1]] # b, max_num_obj, h*w - - overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False, CIoU=True).squeeze(3).clamp(0) - align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) + if self.roll_out: + align_metric = torch.empty((self.bs, self.n_max_boxes, pd_scores.shape[1]), device=pd_scores.device) + overlaps = torch.empty((self.bs, self.n_max_boxes, pd_scores.shape[1]), device=pd_scores.device) + ind_0 = torch.empty(self.n_max_boxes, dtype=torch.long) + for b in range(self.bs): + ind_0[:], ind_2 = b, gt_labels[b].squeeze(-1).long() + # get the scores of each grid for each gt cls + bbox_scores = pd_scores[ind_0, :, ind_2] # b, max_num_obj, h*w + overlaps[b] = bbox_iou(gt_bboxes[b].unsqueeze(1), pd_bboxes[b].unsqueeze(0), xywh=False, + CIoU=True).squeeze(2).clamp(0) + align_metric[b] = bbox_scores.pow(self.alpha) * overlaps[b].pow(self.beta) + else: + ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj + ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes) # b, max_num_obj + ind[1] = gt_labels.long().squeeze(-1) # b, max_num_obj + # get the scores of each grid for each gt cls + bbox_scores = pd_scores[ind[0], :, ind[1]] # b, max_num_obj, h*w + + overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False, + CIoU=True).squeeze(3).clamp(0) + align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) return align_metric, overlaps def select_topk_candidates(self, metrics, largest=True, topk_mask=None): @@ -145,9 +168,14 @@ class TaskAlignedAssigner(nn.Module): if topk_mask is None: topk_mask = (topk_metrics.max(-1, keepdim=True) > self.eps).tile([1, 1, self.topk]) # (b, max_num_obj, topk) - topk_idxs = torch.where(topk_mask, topk_idxs, 0) + topk_idxs[~topk_mask] = 0 # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w) - is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2) + if self.roll_out: + is_in_topk = torch.empty(metrics.shape, dtype=torch.long, device=metrics.device) + for b in range(len(topk_idxs)): + is_in_topk[b] = F.one_hot(topk_idxs[b], num_anchors).sum(-2) + else: + is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2) # filter invalid bboxes is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk) return is_in_topk.to(metrics.dtype) diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py index ce15ba4..c5b3a8b 100644 --- a/ultralytics/yolo/v8/detect/train.py +++ b/ultralytics/yolo/v8/detect/train.py @@ -33,14 +33,15 @@ class DetectionTrainer(BaseTrainer): augment=mode == "train", cache=self.args.cache, pad=0 if mode == "train" else 0.5, - rect=self.args.rect or mode=="val", + rect=self.args.rect or mode == "val", rank=rank, workers=self.args.workers, close_mosaic=self.args.close_mosaic != 0, prefix=colorstr(f'{mode}: '), shuffle=mode == "train", seed=self.args.seed)[0] if self.args.v5loader else \ - build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode, rect=mode=="val")[0] + build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode, + rect=mode == "val")[0] def preprocess_batch(self, batch): batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 @@ -121,7 +122,13 @@ class Loss: self.device = device self.use_dfl = m.reg_max > 1 - self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0) + roll_out_thr = h.min_memory if h.min_memory > 1 else 64 if h.min_memory else 0 # 64 is default + + self.assigner = TaskAlignedAssigner(topk=10, + num_classes=self.nc, + alpha=0.5, + beta=6.0, + roll_out_thr=roll_out_thr) self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device) self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)