Fix Classification train logging (#157)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -11,7 +11,9 @@ import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
# Constants
|
||||
@ -57,8 +59,8 @@ HELP_MSG = \
|
||||
"""
|
||||
|
||||
# Settings
|
||||
# torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
||||
# np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
||||
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
||||
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
||||
pd.options.display.max_columns = 10
|
||||
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
||||
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
||||
|
@ -565,14 +565,8 @@ class SegmentMetrics:
|
||||
@property
|
||||
def keys(self):
|
||||
return [
|
||||
"metrics/precision(B)",
|
||||
"metrics/recall(B)",
|
||||
"metrics/mAP50(B)",
|
||||
"metrics/mAP50-95(B)", # metrics
|
||||
"metrics/precision(M)",
|
||||
"metrics/recall(M)",
|
||||
"metrics/mAP50(M)",
|
||||
"metrics/mAP50-95(M)"]
|
||||
"metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)",
|
||||
"metrics/precision(M)", "metrics/recall(M)", "metrics/mAP50(M)", "metrics/mAP50-95(M)"]
|
||||
|
||||
def mean_results(self):
|
||||
return self.metric_box.mean_results() + self.metric_mask.mean_results()
|
||||
@ -603,7 +597,10 @@ class ClassifyMetrics:
|
||||
self.top1 = 0
|
||||
self.top5 = 0
|
||||
|
||||
def process(self, correct):
|
||||
def process(self, targets, pred):
|
||||
# target classes and predicted classes
|
||||
pred, targets = torch.cat(pred), torch.cat(targets)
|
||||
correct = (targets[:, None] == pred).float()
|
||||
acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy
|
||||
self.top1, self.top5 = acc.mean(0).tolist()
|
||||
|
||||
@ -617,4 +614,4 @@ class ClassifyMetrics:
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
return ["top1", "top5"]
|
||||
return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
|
||||
|
Reference in New Issue
Block a user