ultralytics 8.0.153 YOLO Tasks Cleanup (#4314)

This commit is contained in:
Glenn Jocher
2023-08-12 02:30:57 +02:00
committed by GitHub
parent 39395aedc8
commit 822608986c
22 changed files with 87 additions and 55 deletions

View File

@ -9,7 +9,7 @@ from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
from ultralytics.hub.utils import HUB_WEB_ROOT
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks, emojis,
is_git_dir, yaml_load)
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
@ -448,11 +448,11 @@ class Model:
"""Load model/trainer/validator/predictor."""
try:
return self.task_map[self.task][key]
except Exception:
except Exception as e:
name = self.__class__.__name__
mode = inspect.stack()[1][3] # get the function name.
raise NotImplementedError(
f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')
emojis(f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')) from e
@property
def task_map(self):

View File

@ -51,9 +51,18 @@ class BaseValidator:
device (torch.device): Device to use for validation.
batch_i (int): Current batch index.
training (bool): Whether the model is in training mode.
speed (float): Batch processing speed in seconds.
jdict (dict): Dictionary to store validation results.
names (dict): Class names.
seen: Records the number of images seen so far during validation.
stats: Placeholder for statistics during validation.
confusion_matrix: Placeholder for a confusion matrix.
nc: Number of classes.
iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
jdict (dict): Dictionary to store JSON validation results.
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
batch processing times in milliseconds.
save_dir (Path): Directory to save results.
plots (dict): Dictionary to store plots for visualization.
callbacks (dict): Dictionary to store various callback functions.
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
@ -65,6 +74,7 @@ class BaseValidator:
save_dir (Path): Directory to save results.
pbar (tqdm.tqdm): Progress bar for displaying progress.
args (SimpleNamespace): Configuration for the validator.
_callbacks (dict): Dictionary to store various callback functions.
"""
self.dataloader = dataloader
self.pbar = pbar
@ -74,8 +84,14 @@ class BaseValidator:
self.device = None
self.batch_i = None
self.training = True
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
self.names = None
self.seen = None
self.stats = None
self.confusion_matrix = None
self.nc = None
self.iouv = None
self.jdict = None
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f'{self.args.mode}'
@ -200,26 +216,26 @@ class BaseValidator:
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
return stats
def match_predictions(self, pred_classes: torch.Tensor, true_classes: torch.Tensor,
iou: torch.Tensor) -> torch.Tensor:
def match_predictions(self, pred_classes, true_classes, iou):
"""
Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.
Args:
pred_classes (torch.Tensor): Predicted class indices of shape(N,).
true_classes (torch.Tensor): Target class indices of shape(M,).
iou (torch.Tensor): IoU thresholds from 0.50 to 0.95 in space of 0.05.
Returns:
(torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
"""
correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
correct_class = true_classes[:, None] == pred_classes
for i in range(len(self.iouv)):
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
if x[0].shape[0]:
for i, iouv in enumerate(self.iouv):
x = torch.nonzero(iou.ge(iouv) & correct_class) # IoU > threshold and classes match
if x.shape[0]:
# Concatenate [label, detect, iou]
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
if x[0].shape[0] > 1:
matches = torch.cat((x, iou[x[:, 0], x[:, 1]].unsqueeze(1)), 1).cpu().numpy()
if x.shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]]