From a9dc1637c278018e3b37dc72b82d3758189e9ba5 Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Fri, 30 Dec 2022 01:01:03 +0800 Subject: [PATCH] Mask pycocotools (#116) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/yolo/engine/trainer.py | 1 + ultralytics/yolo/v8/segment/val.py | 80 ++++++++++++++++++++++++------ 2 files changed, 67 insertions(+), 14 deletions(-) diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index bdb589b..10478a4 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -423,6 +423,7 @@ class BaseTrainer: strip_optimizer(f) # strip optimizers if f is self.best: self.console.info(f'\nValidating {f}...') + self.validator.args.save_json = True self.metrics = self.validator(model=f) self.metrics.pop('fitness', None) self.trigger_callbacks('on_val_end') diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py index 2b05b09..05eea88 100644 --- a/ultralytics/yolo/v8/segment/val.py +++ b/ultralytics/yolo/v8/segment/val.py @@ -1,11 +1,14 @@ import os +from multiprocessing.pool import ThreadPool +from pathlib import Path import hydra import numpy as np import torch import torch.nn.functional as F -from ultralytics.yolo.utils import DEFAULT_CONFIG, ops +from ultralytics.yolo.utils import DEFAULT_CONFIG, NUM_THREADS, ops +from ultralytics.yolo.utils.checks import check_requirements from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou from ultralytics.yolo.utils.plotting import output_to_target, plot_images @@ -16,10 +19,6 @@ class SegmentationValidator(DetectionValidator): def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None): super().__init__(dataloader, save_dir, pbar, logger, args) - if self.args.save_json: - self.process = ops.process_mask_upsample # more accurate - else: - self.process = ops.process_mask # faster self.metrics = SegmentMetrics(save_dir=self.save_dir, plot=self.args.plots) def preprocess(self, batch): @@ -51,6 +50,10 @@ class SegmentationValidator(DetectionValidator): self.seen = 0 self.jdict = [] self.stats = [] + if self.args.save_json: + self.process = ops.process_mask_upsample # more accurate + else: + self.process = ops.process_mask # faster def get_desc(self): return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P", @@ -118,16 +121,13 @@ class SegmentationValidator(DetectionValidator): if self.args.plots and self.batch_i < 3: self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot - # TODO: Save/log - ''' - if self.args.save_txt: - save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') + # Save if self.args.save_json: - pred_masks = scale_image(im[si].shape[1:], - pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape, shapes[si][1]) - save_one_json(predn, jdict, path, class_map, pred_masks) # append to COCO-JSON dictionary - # callbacks.run('on_val_image_end', pred, predn, path, names, im[si]) - ''' + pred_masks = ops.scale_image(batch["img"][si].shape[1:], + pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape) + self.pred_to_json(predn, batch["im_file"][si], pred_masks) + # if self.args.save_txt: + # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False): """ @@ -198,6 +198,58 @@ class SegmentationValidator(DetectionValidator): names=self.names) # pred self.plot_masks.clear() + def pred_to_json(self, predn, filename, pred_masks): + # Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} + from pycocotools.mask import encode + + def single_encode(x): + rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0] + rle["counts"] = rle["counts"].decode("utf-8") + return rle + + stem = Path(filename).stem + image_id = int(stem) if stem.isnumeric() else stem + box = ops.xyxy2xywh(predn[:, :4]) # xywh + box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner + pred_masks = np.transpose(pred_masks, (2, 0, 1)) + with ThreadPool(NUM_THREADS) as pool: + rles = pool.map(single_encode, pred_masks) + for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())): + self.jdict.append({ + 'image_id': image_id, + 'category_id': self.class_map[int(p[5])], + 'bbox': [round(x, 3) for x in b], + 'score': round(p[4], 5), + 'segmentation': rles[i]}) + + def eval_json(self, stats): + if self.args.save_json and self.is_coco and len(self.jdict): + anno_json = self.data['path'] / "annotations/instances_val2017.json" # annotations + pred_json = self.save_dir / "predictions.json" # predictions + self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...') + try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb + check_requirements('pycocotools>=2.0.6') + from pycocotools.coco import COCO # noqa + from pycocotools.cocoeval import COCOeval # noqa + + for x in anno_json, pred_json: + assert x.is_file(), f"{x} file not found" + anno = COCO(str(anno_json)) # init annotations api + pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) + for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]): + if self.is_coco: + eval.params.imgIds = [int(Path(x).stem) + for x in self.dataloader.dataset.im_files] # images to eval + eval.evaluate() + eval.accumulate() + eval.summarize() + idx = i * 4 + 2 + stats[self.metric_keys[idx + 1]], stats[ + self.metric_keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50 + except Exception as e: + self.logger.warning(f'pycocotools unable to run: {e}') + return stats + @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name) def val(cfg):