From 400f3f72a1437b6c64e644337a33d5bf7076dfd5 Mon Sep 17 00:00:00 2001 From: Wenchao Ding <40165666+berry-ding@users.noreply.github.com> Date: Thu, 6 Jul 2023 06:16:22 +0800 Subject: [PATCH] `ultralytics 8.0.127` add FastSAM model (#3390) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: dingwenchao <12962189468@163.com> Co-authored-by: 丁文超 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher Co-authored-by: Ayush Chaurasia --- docs/models/fast-sam.md | 67 ++++- ultralytics/__init__.py | 5 +- ultralytics/yolo/fastsam/__init__.py | 8 + ultralytics/yolo/fastsam/model.py | 104 +++++++ ultralytics/yolo/fastsam/predict.py | 51 ++++ ultralytics/yolo/fastsam/prompt.py | 406 +++++++++++++++++++++++++++ ultralytics/yolo/fastsam/utils.py | 63 +++++ ultralytics/yolo/fastsam/val.py | 244 ++++++++++++++++ 8 files changed, 942 insertions(+), 6 deletions(-) create mode 100644 ultralytics/yolo/fastsam/__init__.py create mode 100644 ultralytics/yolo/fastsam/model.py create mode 100644 ultralytics/yolo/fastsam/predict.py create mode 100644 ultralytics/yolo/fastsam/prompt.py create mode 100644 ultralytics/yolo/fastsam/utils.py create mode 100644 ultralytics/yolo/fastsam/val.py diff --git a/docs/models/fast-sam.md b/docs/models/fast-sam.md index 25cb986..a3e0dbc 100644 --- a/docs/models/fast-sam.md +++ b/docs/models/fast-sam.md @@ -32,9 +32,68 @@ FastSAM is designed to address the limitations of the [Segment Anything Model (S ## Usage -FastSAM is not yet available within the [`ultralytics` package](../quickstart.md), but it is available directly from the [https://github.com/CASIA-IVA-Lab/FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) repository. Here is a brief overview of the typical steps you might take to use FastSAM: +### Python API -### Installation +The FastSAM models are easy to integrate into your Python applications. Ultralytics provides a user-friendly Python API to streamline the process. + +#### Predict Usage + +To perform object detection on an image, use the `predict` method as shown below: + +```python +from ultralytics import FastSAM +from ultralytics.yolo.fastsam import FastSAMPrompt + +IMAGE_PATH = 'images/dog.jpg' +DEVICE = 'cpu' +model = FastSAM('FastSAM.pt') +results = model( + IMAGE_PATH, + device=DEVICE, + retina_masks=True, + imgsz=1024, + conf=0.4, + iou=0.9, +) + +prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE) + +# Everything prompt +ann = prompt_process.everything_prompt() + +# Bbox default shape [0,0,0,0] -> [x1,y1,x2,y2] +ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300]) + +# Text prompt +ann = prompt_process.text_prompt(text='a photo of a dog') + +# Point prompt +# points default [[0,0]] [[x1,y1],[x2,y2]] +# point_label default [0] [1,0] 0:background, 1:foreground +ann = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1]) +prompt_process.plot(annotations=ann, output='./') +``` + +This snippet demonstrates the simplicity of loading a pre-trained model and running a prediction on an image. + +#### Val Usage + +Validation of the model on a dataset can be done as follows: + +```python +from ultralytics import FastSAM + +model = FastSAM('FastSAM.pt') +results = model.val(data='coco8-seg.yaml) +``` + +Please note that FastSAM only supports detection and segmentation of a single class of object. This means it will recognize and segment all objects as the same class. Therefore, when preparing the dataset, you need to convert all object category IDs to 0. + +### FastSAM official Usage + +FastSAM is also available directly from the [https://github.com/CASIA-IVA-Lab/FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) repository. Here is a brief overview of the typical steps you might take to use FastSAM: + +#### Installation 1. Clone the FastSAM repository: ```shell @@ -58,7 +117,7 @@ FastSAM is not yet available within the [`ultralytics` package](../quickstart.md pip install git+https://github.com/openai/CLIP.git ``` -### Example Usage +#### Example Usage 1. Download a [model checkpoint](https://drive.google.com/file/d/1m1sjY4ihXBU1fZXdQ-Xdj-mDltW-2Rqv/view?usp=sharing). @@ -101,4 +160,4 @@ We would like to acknowledge the FastSAM authors for their significant contribut } ``` -The original FastSAM paper can be found on [arXiv](https://arxiv.org/abs/2306.12156). The authors have made their work publicly available, and the codebase can be accessed on [GitHub](https://github.com/CASIA-IVA-Lab/FastSAM). We appreciate their efforts in advancing the field and making their work accessible to the broader community. \ No newline at end of file +The original FastSAM paper can be found on [arXiv](https://arxiv.org/abs/2306.12156). The authors have made their work publicly available, and the codebase can be accessed on [GitHub](https://github.com/CASIA-IVA-Lab/FastSAM). We appreciate their efforts in advancing the field and making their work accessible to the broader community. diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index c86adef..d9c6c24 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,13 +1,14 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.126' +__version__ = '8.0.127' from ultralytics.hub import start from ultralytics.vit.rtdetr import RTDETR from ultralytics.vit.sam import SAM from ultralytics.yolo.engine.model import YOLO +from ultralytics.yolo.fastsam import FastSAM from ultralytics.yolo.nas import NAS from ultralytics.yolo.utils.checks import check_yolo as checks from ultralytics.yolo.utils.downloads import download -__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'RTDETR', 'checks', 'start', 'download' # allow simpler import +__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'RTDETR', 'checks', 'start', 'download', 'FastSAM' # allow simpler import diff --git a/ultralytics/yolo/fastsam/__init__.py b/ultralytics/yolo/fastsam/__init__.py new file mode 100644 index 0000000..8f47772 --- /dev/null +++ b/ultralytics/yolo/fastsam/__init__.py @@ -0,0 +1,8 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from .model import FastSAM +from .predict import FastSAMPredictor +from .prompt import FastSAMPrompt +from .val import FastSAMValidator + +__all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator' diff --git a/ultralytics/yolo/fastsam/model.py b/ultralytics/yolo/fastsam/model.py new file mode 100644 index 0000000..5bec2b2 --- /dev/null +++ b/ultralytics/yolo/fastsam/model.py @@ -0,0 +1,104 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +""" +FastSAM model interface. + +Usage - Predict: + from ultralytics import FastSAM + + model = FastSAM('last.pt') + results = model.predict('ultralytics/assets/bus.jpg') +""" + +from ultralytics.yolo.cfg import get_cfg +from ultralytics.yolo.engine.exporter import Exporter +from ultralytics.yolo.engine.model import YOLO +from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir +from ultralytics.yolo.utils.checks import check_imgsz + +from ...yolo.utils.torch_utils import model_info, smart_inference_mode +from .predict import FastSAMPredictor + + +class FastSAM(YOLO): + + @smart_inference_mode() + def predict(self, source=None, stream=False, **kwargs): + """ + Perform prediction using the YOLO model. + + Args: + source (str | int | PIL | np.ndarray): The source of the image to make predictions on. + Accepts all source types accepted by the YOLO model. + stream (bool): Whether to stream the predictions or not. Defaults to False. + **kwargs : Additional keyword arguments passed to the predictor. + Check the 'configuration' section in the documentation for all available options. + + Returns: + (List[ultralytics.yolo.engine.results.Results]): The prediction results. + """ + if source is None: + source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' + LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") + overrides = self.overrides.copy() + overrides['conf'] = 0.25 + overrides.update(kwargs) # prefer kwargs + overrides['mode'] = kwargs.get('mode', 'predict') + assert overrides['mode'] in ['track', 'predict'] + overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python + self.predictor = FastSAMPredictor(overrides=overrides) + self.predictor.setup_model(model=self.model, verbose=False) + + return self.predictor(source, stream=stream) + + def train(self, **kwargs): + """Function trains models but raises an error as FastSAM models do not support training.""" + raise NotImplementedError("FastSAM models don't support training") + + def val(self, **kwargs): + """Run validation given dataset.""" + overrides = dict(task='segment', mode='val') + overrides.update(kwargs) # prefer kwargs + args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) + args.imgsz = check_imgsz(args.imgsz, max_dim=1) + validator = FastSAM(args=args) + validator(model=self.model) + self.metrics = validator.metrics + return validator.metrics + + @smart_inference_mode() + def export(self, **kwargs): + """ + Export model. + + Args: + **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs + """ + overrides = dict(task='detect') + overrides.update(kwargs) + overrides['mode'] = 'export' + args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) + args.task = self.task + if args.imgsz == DEFAULT_CFG.imgsz: + args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed + if args.batch == DEFAULT_CFG.batch: + args.batch = 1 # default to 1 if not modified + return Exporter(overrides=args)(model=self.model) + + def info(self, detailed=False, verbose=True): + """ + Logs model info. + + Args: + detailed (bool): Show detailed information about model. + verbose (bool): Controls verbosity. + """ + return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640) + + def __call__(self, source=None, stream=False, **kwargs): + """Calls the 'predict' function with given arguments to perform object detection.""" + return self.predict(source, stream, **kwargs) + + def __getattr__(self, attr): + """Raises error if object has no requested attribute.""" + name = self.__class__.__name__ + raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") diff --git a/ultralytics/yolo/fastsam/predict.py b/ultralytics/yolo/fastsam/predict.py new file mode 100644 index 0000000..613b072 --- /dev/null +++ b/ultralytics/yolo/fastsam/predict.py @@ -0,0 +1,51 @@ +import torch + +from ultralytics.yolo.engine.results import Results +from ultralytics.yolo.fastsam.utils import bbox_iou +from ultralytics.yolo.utils import DEFAULT_CFG, ops +from ultralytics.yolo.v8.detect.predict import DetectionPredictor + + +class FastSAMPredictor(DetectionPredictor): + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + super().__init__(cfg, overrides, _callbacks) + self.args.task = 'segment' + + def postprocess(self, preds, img, orig_imgs): + """TODO: filter by classes.""" + p = ops.non_max_suppression(preds[0], + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + nc=len(self.model.names), + classes=self.args.classes) + full_box = torch.zeros_like(p[0][0]) + full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0 + full_box = full_box.view(1, -1) + critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:]) + if critical_iou_index.numel() != 0: + full_box[0][4] = p[0][critical_iou_index][:, 4] + full_box[0][6:] = p[0][critical_iou_index][:, 6:] + p[0][critical_iou_index] = full_box + results = [] + proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported + for i, pred in enumerate(p): + orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs + path = self.batch[0] + img_path = path[i] if isinstance(path, list) else path + if not len(pred): # save empty boxes + results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) + continue + if self.args.retina_masks: + if not isinstance(orig_imgs, torch.Tensor): + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC + else: + masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC + if not isinstance(orig_imgs, torch.Tensor): + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + results.append( + Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) + return results diff --git a/ultralytics/yolo/fastsam/prompt.py b/ultralytics/yolo/fastsam/prompt.py new file mode 100644 index 0000000..7999a6d --- /dev/null +++ b/ultralytics/yolo/fastsam/prompt.py @@ -0,0 +1,406 @@ +import os + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +from PIL import Image + +try: + import clip # for linear_assignment + +except (ImportError, AssertionError, AttributeError): + from ultralytics.yolo.utils.checks import check_requirements + + check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source + import clip + + +class FastSAMPrompt: + + def __init__(self, img_path, results, device='cuda') -> None: + # self.img_path = img_path + self.device = device + self.results = results + self.img_path = img_path + self.ori_img = cv2.imread(img_path) + + def _segment_image(self, image, bbox): + image_array = np.array(image) + segmented_image_array = np.zeros_like(image_array) + x1, y1, x2, y2 = bbox + segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2] + segmented_image = Image.fromarray(segmented_image_array) + black_image = Image.new('RGB', image.size, (255, 255, 255)) + # transparency_mask = np.zeros_like((), dtype=np.uint8) + transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8) + transparency_mask[y1:y2, x1:x2] = 255 + transparency_mask_image = Image.fromarray(transparency_mask, mode='L') + black_image.paste(segmented_image, mask=transparency_mask_image) + return black_image + + def _format_results(self, result, filter=0): + annotations = [] + n = len(result.masks.data) + for i in range(n): + annotation = {} + mask = result.masks.data[i] == 1.0 + + if torch.sum(mask) < filter: + continue + annotation['id'] = i + annotation['segmentation'] = mask.cpu().numpy() + annotation['bbox'] = result.boxes.data[i] + annotation['score'] = result.boxes.conf[i] + annotation['area'] = annotation['segmentation'].sum() + annotations.append(annotation) + return annotations + + def filter_masks(annotations): # filte the overlap mask + annotations.sort(key=lambda x: x['area'], reverse=True) + to_remove = set() + for i in range(0, len(annotations)): + a = annotations[i] + for j in range(i + 1, len(annotations)): + b = annotations[j] + if i != j and j not in to_remove: + # check if + if b['area'] < a['area']: + if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8: + to_remove.add(j) + + return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove + + def _get_bbox_from_mask(self, mask): + mask = mask.astype(np.uint8) + contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + x1, y1, w, h = cv2.boundingRect(contours[0]) + x2, y2 = x1 + w, y1 + h + if len(contours) > 1: + for b in contours: + x_t, y_t, w_t, h_t = cv2.boundingRect(b) + # 将多个bbox合并成一个 + x1 = min(x1, x_t) + y1 = min(y1, y_t) + x2 = max(x2, x_t + w_t) + y2 = max(y2, y_t + h_t) + h = y2 - y1 + w = x2 - x1 + return [x1, y1, x2, y2] + + def plot(self, + annotations, + output, + bbox=None, + points=None, + point_label=None, + mask_random_color=True, + better_quality=True, + retina=False, + withContours=True): + if isinstance(annotations[0], dict): + annotations = [annotation['segmentation'] for annotation in annotations] + result_name = os.path.basename(self.img_path) + image = self.ori_img + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + original_h = image.shape[0] + original_w = image.shape[1] + # for MacOS only + # plt.switch_backend('TkAgg') + plt.figure(figsize=(original_w / 100, original_h / 100)) + # Add subplot with no margin. + plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) + plt.margins(0, 0) + plt.gca().xaxis.set_major_locator(plt.NullLocator()) + plt.gca().yaxis.set_major_locator(plt.NullLocator()) + + plt.imshow(image) + if better_quality: + if isinstance(annotations[0], torch.Tensor): + annotations = np.array(annotations.cpu()) + for i, mask in enumerate(annotations): + mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)) + annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)) + if self.device == 'cpu': + annotations = np.array(annotations) + self.fast_show_mask( + annotations, + plt.gca(), + random_color=mask_random_color, + bbox=bbox, + points=points, + pointlabel=point_label, + retinamask=retina, + target_height=original_h, + target_width=original_w, + ) + else: + if isinstance(annotations[0], np.ndarray): + annotations = torch.from_numpy(annotations) + self.fast_show_mask_gpu( + annotations, + plt.gca(), + random_color=mask_random_color, + bbox=bbox, + points=points, + pointlabel=point_label, + retinamask=retina, + target_height=original_h, + target_width=original_w, + ) + if isinstance(annotations, torch.Tensor): + annotations = annotations.cpu().numpy() + if withContours: + contour_all = [] + temp = np.zeros((original_h, original_w, 1)) + for i, mask in enumerate(annotations): + if type(mask) == dict: + mask = mask['segmentation'] + annotation = mask.astype(np.uint8) + if not retina: + annotation = cv2.resize( + annotation, + (original_w, original_h), + interpolation=cv2.INTER_NEAREST, + ) + contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + contour_all.append(contour) + cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2) + color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8]) + contour_mask = temp / 255 * color.reshape(1, 1, -1) + plt.imshow(contour_mask) + + save_path = output + if not os.path.exists(save_path): + os.makedirs(save_path) + plt.axis('off') + fig = plt.gcf() + plt.draw() + + try: + buf = fig.canvas.tostring_rgb() + except AttributeError: + fig.canvas.draw() + buf = fig.canvas.tostring_rgb() + cols, rows = fig.canvas.get_width_height() + img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3) + cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)) + + # CPU post process + def fast_show_mask( + self, + annotation, + ax, + random_color=False, + bbox=None, + points=None, + pointlabel=None, + retinamask=True, + target_height=960, + target_width=960, + ): + msak_sum = annotation.shape[0] + height = annotation.shape[1] + weight = annotation.shape[2] + # 将annotation 按照面积 排序 + areas = np.sum(annotation, axis=(1, 2)) + sorted_indices = np.argsort(areas) + annotation = annotation[sorted_indices] + + index = (annotation != 0).argmax(axis=0) + if random_color: + color = np.random.random((msak_sum, 1, 1, 3)) + else: + color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255]) + transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6 + visual = np.concatenate([color, transparency], axis=-1) + mask_image = np.expand_dims(annotation, -1) * visual + + show = np.zeros((height, weight, 4)) + h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij') + indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) + # 使用向量化索引更新show的值 + show[h_indices, w_indices, :] = mask_image[indices] + if bbox is not None: + x1, y1, x2, y2 = bbox + ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1)) + # draw point + if points is not None: + plt.scatter( + [point[0] for i, point in enumerate(points) if pointlabel[i] == 1], + [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], + s=20, + c='y', + ) + plt.scatter( + [point[0] for i, point in enumerate(points) if pointlabel[i] == 0], + [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], + s=20, + c='m', + ) + + if not retinamask: + show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST) + ax.imshow(show) + + def fast_show_mask_gpu( + self, + annotation, + ax, + random_color=False, + bbox=None, + points=None, + pointlabel=None, + retinamask=True, + target_height=960, + target_width=960, + ): + msak_sum = annotation.shape[0] + height = annotation.shape[1] + weight = annotation.shape[2] + areas = torch.sum(annotation, dim=(1, 2)) + sorted_indices = torch.argsort(areas, descending=False) + annotation = annotation[sorted_indices] + # 找每个位置第一个非零值下标 + index = (annotation != 0).to(torch.long).argmax(dim=0) + if random_color: + color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device) + else: + color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([ + 30 / 255, 144 / 255, 255 / 255]).to(annotation.device) + transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6 + visual = torch.cat([color, transparency], dim=-1) + mask_image = torch.unsqueeze(annotation, -1) * visual + # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式 + show = torch.zeros((height, weight, 4)).to(annotation.device) + h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij') + indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) + # 使用向量化索引更新show的值 + show[h_indices, w_indices, :] = mask_image[indices] + show_cpu = show.cpu().numpy() + if bbox is not None: + x1, y1, x2, y2 = bbox + ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1)) + # draw point + if points is not None: + plt.scatter( + [point[0] for i, point in enumerate(points) if pointlabel[i] == 1], + [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], + s=20, + c='y', + ) + plt.scatter( + [point[0] for i, point in enumerate(points) if pointlabel[i] == 0], + [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], + s=20, + c='m', + ) + if not retinamask: + show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST) + ax.imshow(show_cpu) + + # clip + @torch.no_grad() + def retrieve(self, model, preprocess, elements, search_text: str, device) -> int: + preprocessed_images = [preprocess(image).to(device) for image in elements] + tokenized_text = clip.tokenize([search_text]).to(device) + stacked_images = torch.stack(preprocessed_images) + image_features = model.encode_image(stacked_images) + text_features = model.encode_text(tokenized_text) + image_features /= image_features.norm(dim=-1, keepdim=True) + text_features /= text_features.norm(dim=-1, keepdim=True) + probs = 100.0 * image_features @ text_features.T + return probs[:, 0].softmax(dim=0) + + def _crop_image(self, format_results): + + image = Image.fromarray(cv2.cvtColor(self.ori_img, cv2.COLOR_BGR2RGB)) + ori_w, ori_h = image.size + annotations = format_results + mask_h, mask_w = annotations[0]['segmentation'].shape + if ori_w != mask_w or ori_h != mask_h: + image = image.resize((mask_w, mask_h)) + cropped_boxes = [] + cropped_images = [] + not_crop = [] + filter_id = [] + # annotations, _ = filter_masks(annotations) + # filter_id = list(_) + for _, mask in enumerate(annotations): + if np.sum(mask['segmentation']) <= 100: + filter_id.append(_) + continue + bbox = self._get_bbox_from_mask(mask['segmentation']) # mask 的 bbox + cropped_boxes.append(self._segment_image(image, bbox)) # 保存裁剪的图片 + # cropped_boxes.append(segment_image(image,mask["segmentation"])) + cropped_images.append(bbox) # 保存裁剪的图片的bbox + + return cropped_boxes, cropped_images, not_crop, filter_id, annotations + + def box_prompt(self, bbox): + + assert (bbox[2] != 0 and bbox[3] != 0) + masks = self.results[0].masks.data + target_height = self.ori_img.shape[0] + target_width = self.ori_img.shape[1] + h = masks.shape[1] + w = masks.shape[2] + if h != target_height or w != target_width: + bbox = [ + int(bbox[0] * w / target_width), + int(bbox[1] * h / target_height), + int(bbox[2] * w / target_width), + int(bbox[3] * h / target_height), ] + bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0 + bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0 + bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w + bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h + + # IoUs = torch.zeros(len(masks), dtype=torch.float32) + bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) + + masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2)) + orig_masks_area = torch.sum(masks, dim=(1, 2)) + + union = bbox_area + orig_masks_area - masks_area + IoUs = masks_area / union + max_iou_index = torch.argmax(IoUs) + + return np.array([masks[max_iou_index].cpu().numpy()]) + + def point_prompt(self, points, pointlabel): # numpy 处理 + + masks = self._format_results(self.results[0], 0) + target_height = self.ori_img.shape[0] + target_width = self.ori_img.shape[1] + h = masks[0]['segmentation'].shape[0] + w = masks[0]['segmentation'].shape[1] + if h != target_height or w != target_width: + points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points] + onemask = np.zeros((h, w)) + for i, annotation in enumerate(masks): + if type(annotation) == dict: + mask = annotation['segmentation'] + else: + mask = annotation + for i, point in enumerate(points): + if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: + onemask += mask + if mask[point[1], point[0]] == 1 and pointlabel[i] == 0: + onemask -= mask + onemask = onemask >= 1 + return np.array([onemask]) + + def text_prompt(self, text): + format_results = self._format_results(self.results[0], 0) + cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results) + clip_model, preprocess = clip.load('ViT-B/32', device=self.device) + scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device) + max_idx = scores.argsort() + max_idx = max_idx[-1] + max_idx += sum(np.array(filter_id) <= int(max_idx)) + return np.array([annotations[max_idx]['segmentation']]) + + def everything_prompt(self): + return self.results[0].masks.data diff --git a/ultralytics/yolo/fastsam/utils.py b/ultralytics/yolo/fastsam/utils.py new file mode 100644 index 0000000..b0b8d9e --- /dev/null +++ b/ultralytics/yolo/fastsam/utils.py @@ -0,0 +1,63 @@ +import torch + + +def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20): + '''Adjust bounding boxes to stick to image border if they are within a certain threshold. + Args: + boxes: (n, 4) + image_shape: (height, width) + threshold: pixel threshold + + Returns: + adjusted_boxes: adjusted bounding boxes + ''' + + # Image dimensions + h, w = image_shape + + # Adjust boxes + boxes[:, 0] = torch.where(boxes[:, 0] < threshold, 0, boxes[:, 0]) # x1 + boxes[:, 1] = torch.where(boxes[:, 1] < threshold, 0, boxes[:, 1]) # y1 + boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, w, boxes[:, 2]) # x2 + boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, h, boxes[:, 3]) # y2 + + return boxes + + +def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False): + '''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes. + Args: + box1: (4, ) + boxes: (n, 4) + + Returns: + high_iou_indices: Indices of boxes with IoU > thres + ''' + boxes = adjust_bboxes_to_image_border(boxes, image_shape) + # obtain coordinates for intersections + x1 = torch.max(box1[0], boxes[:, 0]) + y1 = torch.max(box1[1], boxes[:, 1]) + x2 = torch.min(box1[2], boxes[:, 2]) + y2 = torch.min(box1[3], boxes[:, 3]) + + # compute the area of intersection + intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0) + + # compute the area of both individual boxes + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + # compute the area of union + union = box1_area + box2_area - intersection + + # compute the IoU + iou = intersection / union # Should be shape (n, ) + if raw_output: + if iou.numel() == 0: + return 0 + return iou + + # get indices of boxes with IoU > thres + high_iou_indices = torch.nonzero(iou > iou_thres).flatten() + + return high_iou_indices diff --git a/ultralytics/yolo/fastsam/val.py b/ultralytics/yolo/fastsam/val.py new file mode 100644 index 0000000..250bd5e --- /dev/null +++ b/ultralytics/yolo/fastsam/val.py @@ -0,0 +1,244 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from multiprocessing.pool import ThreadPool +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.yolo.utils import LOGGER, NUM_THREADS, ops +from ultralytics.yolo.utils.checks import check_requirements +from ultralytics.yolo.utils.metrics import SegmentMetrics, box_iou, mask_iou +from ultralytics.yolo.utils.plotting import output_to_target, plot_images +from ultralytics.yolo.v8.detect import DetectionValidator + + +class FastSAMValidator(DetectionValidator): + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.""" + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.args.task = 'segment' + self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot) + + def preprocess(self, batch): + """Preprocesses batch by converting masks to float and sending to device.""" + batch = super().preprocess(batch) + batch['masks'] = batch['masks'].to(self.device).float() + return batch + + def init_metrics(self, model): + """Initialize metrics and select mask processing function based on save_json flag.""" + super().init_metrics(model) + self.plot_masks = [] + if self.args.save_json: + check_requirements('pycocotools>=2.0.6') + self.process = ops.process_mask_upsample # more accurate + else: + self.process = ops.process_mask # faster + + def get_desc(self): + """Return a formatted description of evaluation metrics.""" + return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P', + 'R', 'mAP50', 'mAP50-95)') + + def postprocess(self, preds): + """Postprocesses YOLO predictions and returns output detections with proto.""" + p = ops.non_max_suppression(preds[0], + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls, + max_det=self.args.max_det, + nc=self.nc) + proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported + return p, proto + + def update_metrics(self, preds, batch): + """Metrics.""" + for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): + idx = batch['batch_idx'] == si + cls = batch['cls'][idx] + bbox = batch['bboxes'][idx] + nl, npr = cls.shape[0], pred.shape[0] # number of labels, predictions + shape = batch['ori_shape'][si] + correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init + correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init + self.seen += 1 + + if npr == 0: + if nl: + self.stats.append((correct_bboxes, correct_masks, *torch.zeros( + (2, 0), device=self.device), cls.squeeze(-1))) + if self.args.plots: + self.confusion_matrix.process_batch(detections=None, labels=cls.squeeze(-1)) + continue + + # Masks + midx = [si] if self.args.overlap_mask else idx + gt_masks = batch['masks'][midx] + pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:]) + + # Predictions + if self.args.single_cls: + pred[:, 5] = 0 + predn = pred.clone() + ops.scale_boxes(batch['img'][si].shape[1:], predn[:, :4], shape, + ratio_pad=batch['ratio_pad'][si]) # native-space pred + + # Evaluate + if nl: + height, width = batch['img'].shape[2:] + tbox = ops.xywh2xyxy(bbox) * torch.tensor( + (width, height, width, height), device=self.device) # target boxes + ops.scale_boxes(batch['img'][si].shape[1:], tbox, shape, + ratio_pad=batch['ratio_pad'][si]) # native-space labels + labelsn = torch.cat((cls, tbox), 1) # native-space labels + correct_bboxes = self._process_batch(predn, labelsn) + # TODO: maybe remove these `self.` arguments as they already are member variable + correct_masks = self._process_batch(predn, + labelsn, + pred_masks, + gt_masks, + overlap=self.args.overlap_mask, + masks=True) + if self.args.plots: + self.confusion_matrix.process_batch(predn, labelsn) + + # Append correct_masks, correct_boxes, pconf, pcls, tcls + self.stats.append((correct_bboxes, correct_masks, pred[:, 4], pred[:, 5], cls.squeeze(-1))) + + pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) + if self.args.plots and self.batch_i < 3: + self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot + + # Save + if self.args.save_json: + pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), + shape, + ratio_pad=batch['ratio_pad'][si]) + self.pred_to_json(predn, batch['im_file'][si], pred_masks) + # if self.args.save_txt: + # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') + + def finalize_metrics(self, *args, **kwargs): + """Sets speed and confusion matrix for evaluation metrics.""" + self.metrics.speed = self.speed + self.metrics.confusion_matrix = self.confusion_matrix + + def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False): + """ + Return correct prediction matrix + Arguments: + detections (array[N, 6]), x1, y1, x2, y2, conf, class + labels (array[M, 5]), class, x1, y1, x2, y2 + Returns: + correct (array[N, 10]), for 10 IoU levels + """ + if masks: + if overlap: + nl = len(labels) + index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1 + gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640) + gt_masks = torch.where(gt_masks == index, 1.0, 0.0) + if gt_masks.shape[1:] != pred_masks.shape[1:]: + gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0] + gt_masks = gt_masks.gt_(0.5) + iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1)) + else: # boxes + iou = box_iou(labels[:, 1:], detections[:, :4]) + + correct = np.zeros((detections.shape[0], self.iouv.shape[0])).astype(bool) + correct_class = labels[:, 0:1] == detections[:, 5] + for i in range(len(self.iouv)): + x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), + 1).cpu().numpy() # [label, detect, iou] + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + # matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + correct[matches[:, 1].astype(int), i] = True + return torch.tensor(correct, dtype=torch.bool, device=detections.device) + + def plot_val_samples(self, batch, ni): + """Plots validation samples with bounding box labels.""" + plot_images(batch['img'], + batch['batch_idx'], + batch['cls'].squeeze(-1), + batch['bboxes'], + batch['masks'], + paths=batch['im_file'], + fname=self.save_dir / f'val_batch{ni}_labels.jpg', + names=self.names, + on_plot=self.on_plot) + + def plot_predictions(self, batch, preds, ni): + """Plots batch predictions with masks and bounding boxes.""" + plot_images( + batch['img'], + *output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed + torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks, + paths=batch['im_file'], + fname=self.save_dir / f'val_batch{ni}_pred.jpg', + names=self.names, + on_plot=self.on_plot) # pred + self.plot_masks.clear() + + def pred_to_json(self, predn, filename, pred_masks): + """Save one JSON result.""" + # Example result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} + from pycocotools.mask import encode # noqa + + def single_encode(x): + """Encode predicted masks as RLE and append results to jdict.""" + rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0] + rle['counts'] = rle['counts'].decode('utf-8') + return rle + + stem = Path(filename).stem + image_id = int(stem) if stem.isnumeric() else stem + box = ops.xyxy2xywh(predn[:, :4]) # xywh + box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner + pred_masks = np.transpose(pred_masks, (2, 0, 1)) + with ThreadPool(NUM_THREADS) as pool: + rles = pool.map(single_encode, pred_masks) + for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())): + self.jdict.append({ + 'image_id': image_id, + 'category_id': self.class_map[int(p[5])], + 'bbox': [round(x, 3) for x in b], + 'score': round(p[4], 5), + 'segmentation': rles[i]}) + + def eval_json(self, stats): + """Return COCO-style object detection evaluation metrics.""" + if self.args.save_json and self.is_coco and len(self.jdict): + anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations + pred_json = self.save_dir / 'predictions.json' # predictions + LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') + try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb + check_requirements('pycocotools>=2.0.6') + from pycocotools.coco import COCO # noqa + from pycocotools.cocoeval import COCOeval # noqa + + for x in anno_json, pred_json: + assert x.is_file(), f'{x} file not found' + anno = COCO(str(anno_json)) # init annotations api + pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) + for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]): + if self.is_coco: + eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval + eval.evaluate() + eval.accumulate() + eval.summarize() + idx = i * 4 + 2 + stats[self.metrics.keys[idx + 1]], stats[ + self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50 + except Exception as e: + LOGGER.warning(f'pycocotools unable to run: {e}') + return stats