From 38d6df55cbbc007a4bccf13498ecb3c70833498f Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 27 Dec 2022 18:42:36 +0100 Subject: [PATCH] Fix `save_json(predn, batch)` (#105) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/yolo/engine/validator.py | 6 +---- ultralytics/yolo/v8/detect/val.py | 35 ++++++++++++++-------------- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index 4c756f0..aae1daa 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -109,9 +109,6 @@ class BaseValidator: self.plot_val_samples(batch, batch_i) self.plot_predictions(batch, preds, batch_i) - if self.args.save_json: - self.pred_to_json(preds, batch) - stats = self.get_stats() self.check_stats(stats) self.print_results() @@ -126,8 +123,7 @@ class BaseValidator: with open(str(self.save_dir / "predictions.json"), 'w') as f: self.logger.info(f"Saving {f.name}...") json.dump(self.jdict, f) # flatten and save - - stats = self.eval_json(stats) + stats = self.eval_json(stats) # update stats return stats def get_dataloader(self, dataset_path, batch_size): diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py index 763d86d..a8ce7f3 100644 --- a/ultralytics/yolo/v8/detect/val.py +++ b/ultralytics/yolo/v8/detect/val.py @@ -71,7 +71,7 @@ class DetectionValidator(BaseValidator): def update_metrics(self, preds, batch): # Metrics - for si, (pred) in enumerate(preds): + for si, pred in enumerate(preds): labels = self.targets[self.targets[:, 0] == si, 1:] nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions shape = batch["ori_shape"][si] @@ -103,11 +103,11 @@ class DetectionValidator(BaseValidator): self.confusion_matrix.process_batch(predn, labelsn) self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0])) # (conf, pcls, tcls) - # TODO: Save/log - ''' - if self.args.save_txt: - save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') - ''' + # Save + if self.args.save_json: + self.pred_to_json(predn, batch["im_file"][si]) + # if self.args.save_txt: + # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') def get_stats(self): stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy @@ -197,18 +197,17 @@ class DetectionValidator(BaseValidator): fname=self.save_dir / f'val_batch{ni}_pred.jpg', names=self.names) # pred - def pred_to_json(self, preds, batch): - for i, f in enumerate(batch["im_file"]): - stem = Path(f).stem - image_id = int(stem) if stem.isnumeric() else stem - box = ops.xyxy2xywh(preds[i][:, :4]) # xywh - box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner - for p, b in zip(preds[i].tolist(), box.tolist()): - self.jdict.append({ - 'image_id': image_id, - 'category_id': self.class_map[int(p[5])], - 'bbox': [round(x, 3) for x in b], - 'score': round(p[4], 5)}) + def pred_to_json(self, predn, filename): + stem = Path(filename).stem + image_id = int(stem) if stem.isnumeric() else stem + box = ops.xyxy2xywh(predn[:, :4]) # xywh + box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner + for p, b in zip(predn.tolist(), box.tolist()): + self.jdict.append({ + 'image_id': image_id, + 'category_id': self.class_map[int(p[5])], + 'bbox': [round(x, 3) for x in b], + 'score': round(p[4], 5)}) def eval_json(self, stats): if self.args.save_json and self.is_coco and len(self.jdict):