Mask pycocotools (#116)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Laughing 2 years ago committed by GitHub
parent 37e0f6e0e4
commit a9dc1637c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -423,6 +423,7 @@ class BaseTrainer:
strip_optimizer(f) # strip optimizers
if f is self.best:
self.console.info(f'\nValidating {f}...')
self.validator.args.save_json = True
self.metrics = self.validator(model=f)
self.metrics.pop('fitness', None)
self.trigger_callbacks('on_val_end')

@ -1,11 +1,14 @@
import os
from multiprocessing.pool import ThreadPool
from pathlib import Path
import hydra
import numpy as np
import torch
import torch.nn.functional as F
from ultralytics.yolo.utils import DEFAULT_CONFIG, ops
from ultralytics.yolo.utils import DEFAULT_CONFIG, NUM_THREADS, ops
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@ -16,10 +19,6 @@ class SegmentationValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args)
if self.args.save_json:
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
self.metrics = SegmentMetrics(save_dir=self.save_dir, plot=self.args.plots)
def preprocess(self, batch):
@ -51,6 +50,10 @@ class SegmentationValidator(DetectionValidator):
self.seen = 0
self.jdict = []
self.stats = []
if self.args.save_json:
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
def get_desc(self):
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
@ -118,16 +121,13 @@ class SegmentationValidator(DetectionValidator):
if self.args.plots and self.batch_i < 3:
self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
# TODO: Save/log
'''
if self.args.save_txt:
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
# Save
if self.args.save_json:
pred_masks = scale_image(im[si].shape[1:],
pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape, shapes[si][1])
save_one_json(predn, jdict, path, class_map, pred_masks) # append to COCO-JSON dictionary
# callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
'''
pred_masks = ops.scale_image(batch["img"][si].shape[1:],
pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape)
self.pred_to_json(predn, batch["im_file"][si], pred_masks)
# if self.args.save_txt:
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
@ -198,6 +198,58 @@ class SegmentationValidator(DetectionValidator):
names=self.names) # pred
self.plot_masks.clear()
def pred_to_json(self, predn, filename, pred_masks):
# Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
from pycocotools.mask import encode
def single_encode(x):
rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
rle["counts"] = rle["counts"].decode("utf-8")
return rle
stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
pred_masks = np.transpose(pred_masks, (2, 0, 1))
with ThreadPool(NUM_THREADS) as pool:
rles = pool.map(single_encode, pred_masks)
for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
self.jdict.append({
'image_id': image_id,
'category_id': self.class_map[int(p[5])],
'bbox': [round(x, 3) for x in b],
'score': round(p[4], 5),
'segmentation': rles[i]})
def eval_json(self, stats):
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / "annotations/instances_val2017.json" # annotations
pred_json = self.save_dir / "predictions.json" # predictions
self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
assert x.is_file(), f"{x} file not found"
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem)
for x in self.dataloader.dataset.im_files] # images to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
idx = i * 4 + 2
stats[self.metric_keys[idx + 1]], stats[
self.metric_keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
self.logger.warning(f'pycocotools unable to run: {e}')
return stats
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def val(cfg):

Loading…
Cancel
Save