ultralytics 8.0.153
YOLO Tasks Cleanup (#4314)
This commit is contained in:
@ -9,7 +9,7 @@ from ultralytics.cfg import get_cfg
|
||||
from ultralytics.engine.exporter import Exporter
|
||||
from ultralytics.hub.utils import HUB_WEB_ROOT
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
||||
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
|
||||
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks, emojis,
|
||||
is_git_dir, yaml_load)
|
||||
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
|
||||
from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
|
||||
@ -448,11 +448,11 @@ class Model:
|
||||
"""Load model/trainer/validator/predictor."""
|
||||
try:
|
||||
return self.task_map[self.task][key]
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
name = self.__class__.__name__
|
||||
mode = inspect.stack()[1][3] # get the function name.
|
||||
raise NotImplementedError(
|
||||
f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')
|
||||
emojis(f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')) from e
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
|
@ -51,9 +51,18 @@ class BaseValidator:
|
||||
device (torch.device): Device to use for validation.
|
||||
batch_i (int): Current batch index.
|
||||
training (bool): Whether the model is in training mode.
|
||||
speed (float): Batch processing speed in seconds.
|
||||
jdict (dict): Dictionary to store validation results.
|
||||
names (dict): Class names.
|
||||
seen: Records the number of images seen so far during validation.
|
||||
stats: Placeholder for statistics during validation.
|
||||
confusion_matrix: Placeholder for a confusion matrix.
|
||||
nc: Number of classes.
|
||||
iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
|
||||
jdict (dict): Dictionary to store JSON validation results.
|
||||
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
|
||||
batch processing times in milliseconds.
|
||||
save_dir (Path): Directory to save results.
|
||||
plots (dict): Dictionary to store plots for visualization.
|
||||
callbacks (dict): Dictionary to store various callback functions.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
@ -65,6 +74,7 @@ class BaseValidator:
|
||||
save_dir (Path): Directory to save results.
|
||||
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
_callbacks (dict): Dictionary to store various callback functions.
|
||||
"""
|
||||
self.dataloader = dataloader
|
||||
self.pbar = pbar
|
||||
@ -74,8 +84,14 @@ class BaseValidator:
|
||||
self.device = None
|
||||
self.batch_i = None
|
||||
self.training = True
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
self.names = None
|
||||
self.seen = None
|
||||
self.stats = None
|
||||
self.confusion_matrix = None
|
||||
self.nc = None
|
||||
self.iouv = None
|
||||
self.jdict = None
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
|
||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||
name = self.args.name or f'{self.args.mode}'
|
||||
@ -200,26 +216,26 @@ class BaseValidator:
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||
return stats
|
||||
|
||||
def match_predictions(self, pred_classes: torch.Tensor, true_classes: torch.Tensor,
|
||||
iou: torch.Tensor) -> torch.Tensor:
|
||||
def match_predictions(self, pred_classes, true_classes, iou):
|
||||
"""
|
||||
Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.
|
||||
|
||||
Args:
|
||||
pred_classes (torch.Tensor): Predicted class indices of shape(N,).
|
||||
true_classes (torch.Tensor): Target class indices of shape(M,).
|
||||
iou (torch.Tensor): IoU thresholds from 0.50 to 0.95 in space of 0.05.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
|
||||
"""
|
||||
correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
|
||||
correct_class = true_classes[:, None] == pred_classes
|
||||
for i in range(len(self.iouv)):
|
||||
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match
|
||||
if x[0].shape[0]:
|
||||
for i, iouv in enumerate(self.iouv):
|
||||
x = torch.nonzero(iou.ge(iouv) & correct_class) # IoU > threshold and classes match
|
||||
if x.shape[0]:
|
||||
# Concatenate [label, detect, iou]
|
||||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
|
||||
if x[0].shape[0] > 1:
|
||||
matches = torch.cat((x, iou[x[:, 0], x[:, 1]].unsqueeze(1)), 1).cpu().numpy()
|
||||
if x.shape[0] > 1:
|
||||
matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
||||
# matches = matches[matches[:, 2].argsort()[::-1]]
|
||||
|
Reference in New Issue
Block a user