Fix Classification train logging (#157)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher
2023-01-08 17:52:44 +01:00
committed by GitHub
parent d387359f74
commit e79ea1666c
7 changed files with 86 additions and 40 deletions

View File

@ -11,7 +11,9 @@ import uuid
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
import torch
import yaml
# Constants
@ -57,8 +59,8 @@ HELP_MSG = \
"""
# Settings
# torch.set_printoptions(linewidth=320, precision=5, profile='long')
# np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads

View File

@ -565,14 +565,8 @@ class SegmentMetrics:
@property
def keys(self):
return [
"metrics/precision(B)",
"metrics/recall(B)",
"metrics/mAP50(B)",
"metrics/mAP50-95(B)", # metrics
"metrics/precision(M)",
"metrics/recall(M)",
"metrics/mAP50(M)",
"metrics/mAP50-95(M)"]
"metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)",
"metrics/precision(M)", "metrics/recall(M)", "metrics/mAP50(M)", "metrics/mAP50-95(M)"]
def mean_results(self):
return self.metric_box.mean_results() + self.metric_mask.mean_results()
@ -603,7 +597,10 @@ class ClassifyMetrics:
self.top1 = 0
self.top5 = 0
def process(self, correct):
def process(self, targets, pred):
# target classes and predicted classes
pred, targets = torch.cat(pred), torch.cat(targets)
correct = (targets[:, None] == pred).float()
acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy
self.top1, self.top5 = acc.mean(0).tolist()
@ -617,4 +614,4 @@ class ClassifyMetrics:
@property
def keys(self):
return ["top1", "top5"]
return ["metrics/accuracy_top1", "metrics/accuracy_top5"]