Fix `save_json(predn, batch)` (#105)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 8b6466f731
commit 38d6df55cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -109,9 +109,6 @@ class BaseValidator:
self.plot_val_samples(batch, batch_i) self.plot_val_samples(batch, batch_i)
self.plot_predictions(batch, preds, batch_i) self.plot_predictions(batch, preds, batch_i)
if self.args.save_json:
self.pred_to_json(preds, batch)
stats = self.get_stats() stats = self.get_stats()
self.check_stats(stats) self.check_stats(stats)
self.print_results() self.print_results()
@ -126,8 +123,7 @@ class BaseValidator:
with open(str(self.save_dir / "predictions.json"), 'w') as f: with open(str(self.save_dir / "predictions.json"), 'w') as f:
self.logger.info(f"Saving {f.name}...") self.logger.info(f"Saving {f.name}...")
json.dump(self.jdict, f) # flatten and save json.dump(self.jdict, f) # flatten and save
stats = self.eval_json(stats) # update stats
stats = self.eval_json(stats)
return stats return stats
def get_dataloader(self, dataset_path, batch_size): def get_dataloader(self, dataset_path, batch_size):

@ -71,7 +71,7 @@ class DetectionValidator(BaseValidator):
def update_metrics(self, preds, batch): def update_metrics(self, preds, batch):
# Metrics # Metrics
for si, (pred) in enumerate(preds): for si, pred in enumerate(preds):
labels = self.targets[self.targets[:, 0] == si, 1:] labels = self.targets[self.targets[:, 0] == si, 1:]
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
shape = batch["ori_shape"][si] shape = batch["ori_shape"][si]
@ -103,11 +103,11 @@ class DetectionValidator(BaseValidator):
self.confusion_matrix.process_batch(predn, labelsn) self.confusion_matrix.process_batch(predn, labelsn)
self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0])) # (conf, pcls, tcls) self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0])) # (conf, pcls, tcls)
# TODO: Save/log # Save
''' if self.args.save_json:
if self.args.save_txt: self.pred_to_json(predn, batch["im_file"][si])
save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') # if self.args.save_txt:
''' # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def get_stats(self): def get_stats(self):
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
@ -197,18 +197,17 @@ class DetectionValidator(BaseValidator):
fname=self.save_dir / f'val_batch{ni}_pred.jpg', fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names) # pred names=self.names) # pred
def pred_to_json(self, preds, batch): def pred_to_json(self, predn, filename):
for i, f in enumerate(batch["im_file"]): stem = Path(filename).stem
stem = Path(f).stem image_id = int(stem) if stem.isnumeric() else stem
image_id = int(stem) if stem.isnumeric() else stem box = ops.xyxy2xywh(predn[:, :4]) # xywh
box = ops.xyxy2xywh(preds[i][:, :4]) # xywh box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner for p, b in zip(predn.tolist(), box.tolist()):
for p, b in zip(preds[i].tolist(), box.tolist()): self.jdict.append({
self.jdict.append({ 'image_id': image_id,
'image_id': image_id, 'category_id': self.class_map[int(p[5])],
'category_id': self.class_map[int(p[5])], 'bbox': [round(x, 3) for x in b],
'bbox': [round(x, 3) for x in b], 'score': round(p[4], 5)})
'score': round(p[4], 5)})
def eval_json(self, stats): def eval_json(self, stats):
if self.args.save_json and self.is_coco and len(self.jdict): if self.args.save_json and self.is_coco and len(self.jdict):

Loading…
Cancel
Save