Fix save_json(predn, batch) (#105)
				
					
				
			Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		| @ -109,9 +109,6 @@ class BaseValidator: | ||||
|                 self.plot_val_samples(batch, batch_i) | ||||
|                 self.plot_predictions(batch, preds, batch_i) | ||||
|  | ||||
|             if self.args.save_json: | ||||
|                 self.pred_to_json(preds, batch) | ||||
|  | ||||
|         stats = self.get_stats() | ||||
|         self.check_stats(stats) | ||||
|         self.print_results() | ||||
| @ -126,8 +123,7 @@ class BaseValidator: | ||||
|                 with open(str(self.save_dir / "predictions.json"), 'w') as f: | ||||
|                     self.logger.info(f"Saving {f.name}...") | ||||
|                     json.dump(self.jdict, f)  # flatten and save | ||||
|  | ||||
|             stats = self.eval_json(stats) | ||||
|                 stats = self.eval_json(stats)  # update stats | ||||
|             return stats | ||||
|  | ||||
|     def get_dataloader(self, dataset_path, batch_size): | ||||
|  | ||||
| @ -71,7 +71,7 @@ class DetectionValidator(BaseValidator): | ||||
|  | ||||
|     def update_metrics(self, preds, batch): | ||||
|         # Metrics | ||||
|         for si, (pred) in enumerate(preds): | ||||
|         for si, pred in enumerate(preds): | ||||
|             labels = self.targets[self.targets[:, 0] == si, 1:] | ||||
|             nl, npr = labels.shape[0], pred.shape[0]  # number of labels, predictions | ||||
|             shape = batch["ori_shape"][si] | ||||
| @ -103,11 +103,11 @@ class DetectionValidator(BaseValidator): | ||||
|                     self.confusion_matrix.process_batch(predn, labelsn) | ||||
|             self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0]))  # (conf, pcls, tcls) | ||||
|  | ||||
|             # TODO: Save/log | ||||
|             ''' | ||||
|             if self.args.save_txt: | ||||
|                 save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') | ||||
|             ''' | ||||
|             # Save | ||||
|             if self.args.save_json: | ||||
|                 self.pred_to_json(predn, batch["im_file"][si]) | ||||
|             # if self.args.save_txt: | ||||
|             #    save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') | ||||
|  | ||||
|     def get_stats(self): | ||||
|         stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)]  # to numpy | ||||
| @ -197,18 +197,17 @@ class DetectionValidator(BaseValidator): | ||||
|                     fname=self.save_dir / f'val_batch{ni}_pred.jpg', | ||||
|                     names=self.names)  # pred | ||||
|  | ||||
|     def pred_to_json(self, preds, batch): | ||||
|         for i, f in enumerate(batch["im_file"]): | ||||
|             stem = Path(f).stem | ||||
|             image_id = int(stem) if stem.isnumeric() else stem | ||||
|             box = ops.xyxy2xywh(preds[i][:, :4])  # xywh | ||||
|             box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner | ||||
|             for p, b in zip(preds[i].tolist(), box.tolist()): | ||||
|                 self.jdict.append({ | ||||
|                     'image_id': image_id, | ||||
|                     'category_id': self.class_map[int(p[5])], | ||||
|                     'bbox': [round(x, 3) for x in b], | ||||
|                     'score': round(p[4], 5)}) | ||||
|     def pred_to_json(self, predn, filename): | ||||
|         stem = Path(filename).stem | ||||
|         image_id = int(stem) if stem.isnumeric() else stem | ||||
|         box = ops.xyxy2xywh(predn[:, :4])  # xywh | ||||
|         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner | ||||
|         for p, b in zip(predn.tolist(), box.tolist()): | ||||
|             self.jdict.append({ | ||||
|                 'image_id': image_id, | ||||
|                 'category_id': self.class_map[int(p[5])], | ||||
|                 'bbox': [round(x, 3) for x in b], | ||||
|                 'score': round(p[4], 5)}) | ||||
|  | ||||
|     def eval_json(self, stats): | ||||
|         if self.args.save_json and self.is_coco and len(self.jdict): | ||||
|  | ||||
		Reference in New Issue
	
	Block a user