ultralytics 8.0.48 Edge TPU fix and Metrics updates (#1171)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: majid nasiri <majnasai@gmail.com>
Glenn Jocher
2023-02-27 21:34:22 -08:00
committed by GitHub
parent a58f766f94
commit 74e4c94806
23 changed files with 426 additions and 245 deletions


@ -13,7 +13,7 @@ from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_P
CLI_HELP_MSG = \
f"""
Arguments received: {str(['yolo'] + sys.argv[1:])}. Note that Ultralytics 'yolo' commands use the following syntax:
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
yolo TASK MODE ARGS
@ -217,6 +217,9 @@ def entrypoint(debug=''):
if a.startswith('--'):
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
a = a[2:]
if a.endswith(','):
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
a = a[:-1]
if '=' in a:
try:
re.sub(r' *= *', '=', a) # remove spaces around equals sign
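Note: the new cleanup tolerates '--key', trailing commas from copy-pasted arg lists, and spaces around '='. A minimal standalone sketch of the same logic (hypothetical helper name):

import re

def clean_cli_arg(a: str) -> str:
    if a.startswith('--'):
        a = a[2:]  # 'yolo' args take no leading dashes
    if a.endswith(','):
        a = a[:-1]  # tolerate trailing commas
    return re.sub(r' *= *', '=', a)  # remove spaces around equals sign

print([clean_cli_arg(a) for a in ['--imgsz = 640,', 'model=yolov8n.pt,']])
# -> ['imgsz=640', 'model=yolov8n.pt']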
@ -284,6 +287,9 @@ def entrypoint(debug=''):
model = YOLO(model, task=task)
# Task Update
if task and task != model.task:
LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
f'This may produce errors.')
task = task or model.task
overrides['task'] = task


@ -243,15 +243,12 @@ class Exporter:
if coreml: # CoreML
f[4], _ = self._export_coreml()
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
LOGGER.warning('WARNING ⚠️ YOLOv8 TensorFlow export is still under development. '
'Please consider contributing to the effort if you have TF expertise. Thank you!')
nms = False
self.args.int8 |= edgetpu
f[5], s_model = self._export_saved_model()
if pb or tfjs: # pb prerequisite to tfjs
f[6], _ = self._export_pb(s_model)
if tflite:
f[7], _ = self._export_tflite(s_model, nms=nms, agnostic_nms=self.args.agnostic_nms)
f[7], _ = self._export_tflite(s_model, nms=False, agnostic_nms=self.args.agnostic_nms)
if edgetpu:
f[8], _ = self._export_edgetpu(tflite_model=str(
Path(f[5]) / (self.file.stem + '_full_integer_quant.tflite'))) # int8 in/out
@ -619,20 +616,18 @@ class Exporter:
@try_export
def _export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
# YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
LOGGER.warning(f'{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185')
cmd = 'edgetpu_compiler --version'
help_url = 'https://coral.ai/docs/edgetpu/compiler/'
assert LINUX, f'export only supported on Linux. See {help_url}'
if subprocess.run(f'{cmd} > /dev/null', shell=True).returncode != 0:
if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
for c in (
# 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -', # errors
'wget --no-check-certificate -q -O - https://packages.cloud.google.com/apt/doc/apt-key.gpg | '
'sudo apt-key add -',
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | ' # no comma
'sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
'sudo apt-get update',
'sudo apt-get install edgetpu-compiler'):
'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
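Note: replacing the shell redirect '> /dev/null' with subprocess.DEVNULL silences the probe without relying on a POSIX shell. A sketch of the pattern (hypothetical helper name):

import subprocess

def has_command(cmd: str) -> bool:
    # returncode 0 means the binary exists and ran successfully
    return subprocess.run(cmd, shell=True,
                          stdout=subprocess.DEVNULL,
                          stderr=subprocess.DEVNULL).returncode == 0

print(has_command('edgetpu_compiler --version'))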


@ -43,7 +43,7 @@ class YOLO:
cfg (str): The model configuration if loaded from *.yaml file.
ckpt_path (str): The checkpoint file path.
overrides (dict): Overrides for the trainer object.
metrics_data (Any): The data for metrics.
metrics (Any): The data for metrics.
Methods:
__call__(source=None, stream=False, **kwargs):
@ -67,7 +67,7 @@ class YOLO:
list(ultralytics.yolo.engine.results.Results): The prediction results.
"""
def __init__(self, model='yolov8n.pt', task=None) -> None:
def __init__(self, model='yolov8n.pt', task=None, session=None) -> None:
"""
Initializes the YOLO model.
@ -83,7 +83,8 @@ class YOLO:
self.cfg = None # if loaded from *.yaml
self.ckpt_path = None
self.overrides = {} # overrides for trainer object
self.metrics_data = None
self.metrics = None # validation/training metrics
self.session = session # HUB session
# Load or create new YOLO model
suffix = Path(model).suffix
@ -184,6 +185,7 @@ class YOLO:
self._check_is_pytorch_model()
self.model.fuse()
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
@ -217,7 +219,6 @@ class YOLO:
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
@smart_inference_mode()
def track(self, source=None, stream=False, **kwargs):
from ultralytics.tracker import register_tracker
register_tracker(self)
@ -252,7 +253,7 @@ class YOLO:
validator = TASK_MAP[self.task][2](args=args)
validator(model=self.model)
self.metrics_data = validator.metrics
self.metrics = validator.metrics
return validator.metrics
@ -314,12 +315,13 @@ class YOLO:
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.hub_session = self.session # attach optional HUB session
self.trainer.train()
# update model and cfg after training
if RANK in {0, -1}:
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
self.overrides = self.model.args
self.metrics_data = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
def to(self, device):
"""
@ -352,15 +354,6 @@ class YOLO:
"""
return self.model.transforms if hasattr(self.model, 'transforms') else None
@property
def metrics(self):
"""
Returns metrics if computed
"""
if not self.metrics_data:
LOGGER.info('No metrics data found! Run training or validation operation first.')
return self.metrics_data
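Note: with this property removed, metrics is a plain attribute populated by val() and train(). A hedged usage sketch (dataset name illustrative):

from ultralytics import YOLO

model = YOLO('yolov8n.pt')
metrics = model.val(data='coco128.yaml')  # return value is also stored on the instance
assert metrics is model.metrics
print(model.metrics.box.map50)  # mean AP@0.5 from the last validation run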
@staticmethod
def add_callback(event: str, func):
"""


@ -139,7 +139,8 @@ class Results:
annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im)
if logits is not None:
top5i = logits.argsort(0, descending=True)[:5].tolist() # top 5 indices
n5 = min(len(self.names), 5)
top5i = logits.argsort(0, descending=True)[:n5].tolist() # top 5 indices
text = f"{', '.join(f'{names[j] if names else j} {logits[j]:.2f}' for j in top5i)}, "
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
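Note: clamping to n5 = min(len(names), 5) keeps the "top-5" readout consistent with models that have fewer than five classes; slicing alone would not error, but the clamp keeps k aligned with the class count for top-k bookkeeping. A small sketch of the indexing:

import torch

names = {0: 'cat', 1: 'dog', 2: 'bird'}  # classifier with only 3 classes
logits = torch.tensor([0.10, 0.75, 0.15])
n5 = min(len(names), 5)
top5i = logits.argsort(0, descending=True)[:n5].tolist()
print([(names[j], round(float(logits[j]), 2)) for j in top5i])
# -> [('dog', 0.75), ('bird', 0.15), ('cat', 0.1)]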


@ -243,6 +243,24 @@ def is_docker() -> bool:
return False
def is_online() -> bool:
"""
Check internet connectivity by attempting to connect to a known online host.
Returns:
bool: True if connection is successful, False otherwise.
"""
import socket
with contextlib.suppress(Exception):
host = socket.gethostbyname('www.github.com')
socket.create_connection((host, 80), timeout=2)
return True
return False
ONLINE = is_online()
def is_pip_package(filepath: str = __name__) -> bool:
"""
Determines if the file at the given filepath is part of a pip package.
@ -513,6 +531,7 @@ def set_sentry():
RANK in {-1, 0} and \
Path(sys.argv[0]).name == 'yolo' and \
not TESTS_RUNNING and \
ONLINE and \
((is_pip_package() and not is_git_dir()) or
(get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git' and get_git_branch() == 'main')):
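Note: ONLINE is evaluated once at import time and now gates Sentry setup, so connectivity changes later in a long-lived process are not reflected; call is_online() directly when a fresh check matters. A sketch with a hypothetical caller:

from ultralytics.yolo.utils import ONLINE, is_online

def upload_telemetry(payload: bytes):
    if not ONLINE:       # cached result from import time
        return
    if not is_online():  # optional fresh re-check before the network call
        return
    print(f'would upload {len(payload)} bytes')  # placeholder for the real call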


@ -151,4 +151,5 @@ def add_integration_callbacks(instance):
for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks:
for k, v in x.items():
instance.callbacks[k].append(v) # callback[name].append(func)
if v not in instance.callbacks[k]: # prevent duplicate callback additions
instance.callbacks[k].append(v) # callback[name].append(func)
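Note: the membership test makes callback registration idempotent when add_integration_callbacks() runs more than once. A self-contained sketch:

from collections import defaultdict

callbacks = defaultdict(list)

def on_train_start(trainer):
    pass

for _ in range(2):  # simulate registering integrations twice
    if on_train_start not in callbacks['on_train_start']:
        callbacks['on_train_start'].append(on_train_start)

assert len(callbacks['on_train_start']) == 1  # no duplicate hooks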


@ -4,24 +4,33 @@ import json
from time import time
from ultralytics.hub.utils import PREFIX, traces
from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
from ultralytics.yolo.utils import LOGGER
from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
def on_pretrain_routine_end(trainer):
session = not TESTS_RUNNING and getattr(trainer, 'hub_session', None)
session = getattr(trainer, 'hub_session', None)
if session:
# Start timer for upload rate limit
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
session.t = {'metrics': time(), 'ckpt': time()} # start timer on self.rate_limit
session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit
def on_fit_epoch_end(trainer):
session = getattr(trainer, 'hub_session', None)
if session:
session.metrics_queue[trainer.epoch] = json.dumps(trainer.metrics) # json string
if time() - session.t['metrics'] > session.rate_limits['metrics']:
# Upload metrics after val end
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
if trainer.epoch == 0:
model_info = {
'model/parameters': get_num_params(trainer.model),
'model/GFLOPs': round(get_flops(trainer.model), 3),
'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
all_plots = {**all_plots, **model_info}
session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
if time() - session.timers['metrics'] > session.rate_limits['metrics']:
session.upload_metrics()
session.t['metrics'] = time() # reset timer
session.timers['metrics'] = time() # reset timer
session.metrics_queue = {} # reset queue
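Note: the renamed session.timers dict implements a simple rate limiter: metrics queue up every epoch, but the upload fires at most once per window. A sketch with illustrative limits:

from time import time

class Session:
    def __init__(self):
        self.rate_limits = {'metrics': 300.0, 'ckpt': 900.0}  # seconds, illustrative
        self.timers = {'metrics': time(), 'ckpt': time()}
        self.metrics_queue = {}

    def queue_and_maybe_upload(self, epoch, metrics):
        self.metrics_queue[epoch] = metrics
        if time() - self.timers['metrics'] > self.rate_limits['metrics']:
            print(f'uploading {len(self.metrics_queue)} queued epochs')
            self.timers['metrics'] = time()  # reset timer
            self.metrics_queue = {}  # reset queue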
@ -30,21 +39,21 @@ def on_model_save(trainer):
if session:
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - session.t['ckpt'] > session.rate_limits['ckpt']:
if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
LOGGER.info(f'{PREFIX}Uploading checkpoint {session.model_id}')
session.upload_model(trainer.epoch, trainer.last, is_best)
session.t['ckpt'] = time() # reset timer
session.timers['ckpt'] = time() # reset timer
def on_train_end(trainer):
session = getattr(trainer, 'hub_session', None)
if session:
# Upload final model and metrics with exponential backoff
LOGGER.info(f'{PREFIX}Training completed successfully ✅\n'
f'{PREFIX}Uploading final {session.model_id}')
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics['metrics/mAP50-95(B)'], final=True)
session.shutdown() # stop heartbeats
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
LOGGER.info(f'{PREFIX}Syncing final model...')
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
session.alive = False # stop heartbeats
LOGGER.info(f'{PREFIX}Done ✅\n'
f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
def on_train_start(trainer):


@ -1,8 +1,12 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
from torch.utils.tensorboard import SummaryWriter
try:
from torch.utils.tensorboard import SummaryWriter
from ultralytics.yolo.utils import LOGGER
assert not TESTS_RUNNING # do not log pytest
except (ImportError, AssertionError):
SummaryWriter = None
writer = None # TensorBoard SummaryWriter instance
@ -18,7 +22,6 @@ def on_pretrain_routine_start(trainer):
try:
writer = SummaryWriter(str(trainer.save_dir))
except Exception as e:
writer = None # TensorBoard SummaryWriter instance
LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
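Note: the guarded import means SummaryWriter is None when TensorBoard is missing or tests are running, so every logging call must tolerate that. A condensed sketch of the pattern:

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None  # TensorBoard not installed

def log_scalars(writer, scalars, step):
    if writer:  # silently no-op when logging is unavailable
        for k, v in scalars.items():
            writer.add_scalar(k, v, step)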


@ -21,7 +21,7 @@ import torch
from matplotlib import font_manager
from ultralytics.yolo.utils import (AUTOINSTALL, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads, emojis,
is_colab, is_docker, is_jupyter)
is_colab, is_docker, is_jupyter, is_online)
def is_ascii(s) -> bool:
@ -171,21 +171,6 @@ def check_font(font='Arial.ttf'):
return file
def check_online() -> bool:
"""
Check internet connectivity by attempting to connect to a known online host.
Returns:
bool: True if connection is successful, False otherwise.
"""
import socket
with contextlib.suppress(Exception):
host = socket.gethostbyname('www.github.com')
socket.create_connection((host, 80), timeout=2)
return True
return False
def check_python(minimum: str = '3.7.0') -> bool:
"""
Check current python version against the required minimum version.
@ -229,7 +214,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
if s and install and AUTOINSTALL: # check environment variable
LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
try:
assert check_online(), 'AutoUpdate skipped (offline)'
assert is_online(), 'AutoUpdate skipped (offline)'
LOGGER.info(subprocess.check_output(f'pip install {s} {cmds}', shell=True).decode())
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
@ -249,13 +234,13 @@ def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}'
def check_yolov5u_filename(file: str):
def check_yolov5u_filename(file: str, verbose: bool = True):
# Replace legacy YOLOv5 filenames with updated YOLOv5u filenames
if 'yolov3' in file or 'yolov5' in file and 'u' not in file:
original_file = file
file = re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', file) # i.e. yolov5n.pt -> yolov5nu.pt
file = re.sub(r'(.*yolov3(|-tiny|-spp))\.', '\\1u.', file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
if file != original_file:
if file != original_file and verbose:
LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs '
f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n')
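Note: the two substitutions map legacy YOLOv5/YOLOv3 filenames to their 'u' counterparts. A quick demo of the regexes (hypothetical wrapper name):

import re

def to_u_filename(file: str) -> str:
    file = re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', file)  # yolov5n.pt -> yolov5nu.pt
    return re.sub(r'(.*yolov3(|-tiny|-spp))\.', '\\1u.', file)  # yolov3-spp.pt -> yolov3-sppu.pt

print(to_u_filename('yolov5n.pt'), to_u_filename('yolov3-spp.pt'))
# -> yolov5nu.pt yolov3-sppu.pt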


@ -12,7 +12,7 @@ import requests
import torch
from tqdm import tqdm
from ultralytics.yolo.utils import LOGGER, checks
from ultralytics.yolo.utils import LOGGER, checks, is_online
GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \
[f'yolov5{size}u.pt' for size in 'nsmlx'] + \
@ -112,7 +112,7 @@ def safe_download(url,
break # success
f.unlink() # remove partial downloads
except Exception as e:
if i == 0 and not checks.check_online():
if i == 0 and not is_online():
raise ConnectionError(f'❌ Download failure for {url}. Environment is not online.') from e
elif i >= retry:
raise ConnectionError(f'❌ Download failure for {url}. Retry limit reached.') from e
@ -134,8 +134,7 @@ def safe_download(url,
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
from ultralytics.yolo.utils import SETTINGS
from ultralytics.yolo.utils.checks import check_yolov5u_filename
from ultralytics.yolo.utils import SETTINGS # scoped for circular import
def github_assets(repository, version='latest'):
# Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])
@ -146,7 +145,7 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
# YOLOv3/5u updates
file = str(file)
file = check_yolov5u_filename(file)
file = checks.check_yolov5u_filename(file)
file = Path(file.strip().replace("'", ''))
if file.exists():
return str(file)


@ -43,16 +43,18 @@ def bbox_ioa(box1, box2, eps=1e-7):
def box_iou(box1, box2, eps=1e-7):
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
Arguments:
box1 (Tensor[N, 4])
box2 (Tensor[M, 4])
eps
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2
iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
"""
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
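Note: a worked example of the pairwise IoU described in the docstring, using the same broadcasting trick as the inter(N,M) comment above (standalone reimplementation, not the library function itself):

import torch

def pairwise_iou(box1, box2, eps=1e-7):
    (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)  # (N, M)
    return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)

b1 = torch.tensor([[0., 0., 2., 2.]])
b2 = torch.tensor([[1., 1., 3., 3.]])
print(pairwise_iou(b1, b2))  # inter=1, union=7 -> tensor([[0.1429]])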
@ -109,7 +111,7 @@ def mask_iou(mask1, mask2, eps=1e-7):
mask1: [N, n] m1 means number of predicted objects
mask2: [M, n] m2 means number of gt objects
Note: n means image_w x image_h
return: masks iou, [N, M]
Returns: masks iou, [N, M]
"""
intersection = torch.matmul(mask1, mask2.t()).clamp(0)
union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection
@ -121,7 +123,7 @@ def masks_iou(mask1, mask2, eps=1e-7):
mask1: [N, n] m1 means number of predicted objects
mask2: [N, n] m2 means number of gt objects
Note: n means image_w x image_h
return: masks iou, (N, )
Returns: masks iou, (N, )
"""
intersection = (mask1 * mask2).sum(1).clamp(0) # (N, )
union = (mask1.sum(1) + mask2.sum(1))[None] - intersection # (area1 + area2) - intersection
@ -317,10 +319,10 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
def compute_ap(recall, precision):
""" Compute the average precision, given the recall and precision curves
# Arguments
Arguments:
recall: The recall curve (list)
precision: The precision curve (list)
# Returns
Returns:
Average precision, precision curve, recall curve
"""
@ -344,17 +346,30 @@ def compute_ap(recall, precision):
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=''):
""" Compute the average precision, given the recall and precision curves.
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
# Arguments
tp: True positives (nparray, nx1 or nx10).
conf: Objectness value from 0-1 (nparray).
pred_cls: Predicted object classes (nparray).
target_cls: True object classes (nparray).
plot: Plot precision-recall curve at mAP@0.5
save_dir: Plot save directory
# Returns
The average precision as computed in py-faster-rcnn.
"""
Computes the average precision per class for object detection evaluation.
Args:
tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
conf (np.ndarray): Array of confidence scores of the detections.
pred_cls (np.ndarray): Array of predicted classes of the detections.
target_cls (np.ndarray): Array of true classes of the detections.
plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
Returns:
(tuple): A tuple of six arrays and one array of unique classes, where:
tp (np.ndarray): True positive counts for each class.
fp (np.ndarray): False positive counts for each class.
p (np.ndarray): Precision values at each confidence threshold.
r (np.ndarray): Recall values at each confidence threshold.
f1 (np.ndarray): F1-score values at each confidence threshold.
ap (np.ndarray): Average precision for each class at different IoU thresholds.
unique_classes (np.ndarray): An array of unique classes that have data.
"""
# Sort by objectness
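Note: ap_per_class relies on compute_ap above for each class/IoU pair. One common realization of compute_ap builds a monotonic precision envelope and integrates with 101-point interpolation; a hedged sketch (mirrors a standard implementation, not necessarily the exact body here):

import numpy as np

recall = np.array([0.0, 0.5, 1.0])
precision = np.array([1.0, 0.8, 0.6])

mrec = np.concatenate(([0.0], recall, [1.0]))
mpre = np.concatenate(([1.0], precision, [0.0]))
mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))  # precision envelope

x = np.linspace(0, 1, 101)  # 101-point interpolation (COCO-style)
ap = np.trapz(np.interp(x, mrec, mpre), x)
print(round(float(ap), 2))  # ≈ 0.8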
@ -411,6 +426,32 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na
class Metric:
"""
Class for computing evaluation metrics for YOLOv8 model.
Attributes:
p (list): Precision for each class. Shape: (nc,).
r (list): Recall for each class. Shape: (nc,).
f1 (list): F1 score for each class. Shape: (nc,).
all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
ap_class_index (list): Index of class for each AP score. Shape: (nc,).
nc (int): Number of classes.
Methods:
ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
mp(): Mean precision of all classes. Returns: Float.
mr(): Mean recall of all classes. Returns: Float.
map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.
map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.
map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.
mean_results(): Mean of results, returns mp, mr, map50, map.
class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].
maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).
fitness(): Model fitness as a weighted combination of metrics. Returns: Float.
update(results): Update metric attributes with new evaluation results.
"""
def __init__(self) -> None:
self.p = [] # (nc, )
@ -420,10 +461,14 @@ class Metric:
self.ap_class_index = [] # (nc, )
self.nc = 0
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@property
def ap50(self):
"""AP@0.5 of all classes.
Return:
Returns:
(nc, ) or [].
"""
return self.all_ap[:, 0] if len(self.all_ap) else []
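Note: the new __getattr__ turns a mistyped attribute access into a self-documenting error, embedding the class docstring that lists valid attributes. Quick sketch of the effect (assuming this module's import path):

from ultralytics.yolo.utils.metrics import Metric

m = Metric()
try:
    m.map95  # no such attribute
except AttributeError as e:
    print(e)  # message ends with the docstring's attribute list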
@ -431,7 +476,7 @@ class Metric:
@property
def ap(self):
"""AP@0.5:0.95
Return:
Returns:
(nc, ) or [].
"""
return self.all_ap.mean(1) if len(self.all_ap) else []
@ -439,7 +484,7 @@ class Metric:
@property
def mp(self):
"""mean precision of all classes.
Return:
Returns:
float.
"""
return self.p.mean() if len(self.p) else 0.0
@ -447,7 +492,7 @@ class Metric:
@property
def mr(self):
"""mean recall of all classes.
Return:
Returns:
float.
"""
return self.r.mean() if len(self.r) else 0.0
@ -455,7 +500,7 @@ class Metric:
@property
def map50(self):
"""Mean AP@0.5 of all classes.
Return:
Returns:
float.
"""
return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
@ -463,7 +508,7 @@ class Metric:
@property
def map75(self):
"""Mean AP@0.75 of all classes.
Return:
Returns:
float.
"""
return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
@ -471,7 +516,7 @@ class Metric:
@property
def map(self):
"""Mean AP@0.5:0.95 of all classes.
Return:
Returns:
float.
"""
return self.all_ap.mean() if len(self.all_ap) else 0.0
@ -506,6 +551,32 @@ class Metric:
class DetMetrics:
"""
This class is a utility class for computing detection metrics such as precision, recall, and mean average precision
(mAP) of an object detection model.
Args:
save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
names (tuple of str): A tuple of strings that represents the names of the classes. Defaults to an empty tuple.
Attributes:
save_dir (Path): A path to the directory where the output plots will be saved.
plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
names (tuple of str): A tuple of strings that represents the names of the classes.
box (Metric): An instance of the Metric class for storing the results of the detection metrics.
speed (dict): A dictionary for storing the execution time of different parts of the detection process.
Methods:
process(tp, conf, pred_cls, target_cls): Updates the metric results with the latest batch of predictions.
keys: Returns a list of keys for accessing the computed detection metrics.
mean_results: Returns a list of mean values for the computed detection metrics.
class_result(i): Returns a list of values for the computed detection metrics for a specific class.
maps: Returns mean average precision (mAP) scores per class, averaged over IoU thresholds from 0.50 to 0.95.
fitness: Computes the fitness score based on the computed detection metrics.
ap_class_index: Returns a list of class indices sorted by their average precision (AP) values.
results_dict: Returns a dictionary that maps detection metric keys to their computed values.
"""
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
self.save_dir = save_dir
@ -514,6 +585,10 @@ class DetMetrics:
self.box = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, tp, conf, pred_cls, target_cls):
results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir,
names=self.names)[2:]
@ -548,6 +623,31 @@ class DetMetrics:
class SegmentMetrics:
"""
Calculates and aggregates detection and segmentation metrics over a given set of classes.
Args:
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
plot (bool): Whether to save the detection and segmentation plots. Default is False.
names (list): List of class names. Default is an empty list.
Attributes:
save_dir (Path): Path to the directory where the output plots should be saved.
plot (bool): Whether to save the detection and segmentation plots.
names (list): List of class names.
box (Metric): An instance of the Metric class to calculate box detection metrics.
seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
speed (dict): Dictionary to store the time taken in different phases of inference.
Methods:
process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
mean_results(): Returns the mean of the detection and segmentation metrics over all the classes.
class_result(i): Returns the detection and segmentation metrics of class `i`.
maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
fitness: Returns the fitness scores, which are a single weighted combination of metrics.
ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
"""
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
self.save_dir = save_dir
@ -557,7 +657,22 @@ class SegmentMetrics:
self.seg = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, tp_m, tp_b, conf, pred_cls, target_cls):
"""
Processes the detection and segmentation metrics over the given set of predictions.
Args:
tp_m (list): List of True Positive masks.
tp_b (list): List of True Positive boxes.
conf (list): List of confidence scores.
pred_cls (list): List of predicted classes.
target_cls (list): List of target classes.
"""
results_mask = ap_per_class(tp_m,
conf,
pred_cls,
@ -610,12 +725,32 @@ class SegmentMetrics:
class ClassifyMetrics:
"""
Class for computing classification metrics including top-1 and top-5 accuracy.
Attributes:
top1 (float): The top-1 accuracy.
top5 (float): The top-5 accuracy.
speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline.
Properties:
fitness (float): The fitness of the model, which is equal to top-5 accuracy.
results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness.
keys (List[str]): A list of keys for the results_dict.
Methods:
process(targets, pred): Processes the targets and predictions to compute classification metrics.
"""
def __init__(self) -> None:
self.top1 = 0
self.top5 = 0
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, targets, pred):
# target classes and predicted classes
pred, targets = torch.cat(pred), torch.cat(targets)


@ -301,14 +301,14 @@ def plot_images(images,
# Plot masks
if len(masks):
if masks.max() > 1.0: # means that masks overlap
if idx.shape[0] == masks.shape[0]: # overlap_masks=False
image_masks = masks[idx]
else: # overlap_masks=True
image_masks = masks[[i]] # (1, 640, 640)
nl = idx.sum()
index = np.arange(nl).reshape(nl, 1, 1) + 1
image_masks = np.repeat(image_masks, nl, axis=0)
image_masks = np.where(image_masks == index, 1.0, 0.0)
else:
image_masks = masks[idx]
im = np.asarray(annotator.im).copy()
for j, box in enumerate(boxes.T.tolist()):
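Note: with overlap_masks=True the loader stores all instances in a single mask with pixel values 1..n; the reshape/broadcast above splits it back into per-instance binary masks. A standalone demo:

import numpy as np

mask = np.array([[0, 1, 1],
                 [2, 2, 0]], dtype=np.float32)  # two instances encoded in one array
nl = 2  # number of labels in this image
index = np.arange(nl).reshape(nl, 1, 1) + 1  # instance ids 1 and 2
image_masks = np.repeat(mask[None], nl, axis=0)  # (2, 2, 3)
image_masks = np.where(image_masks == index, 1.0, 0.0)
print(image_masks[0])  # binary mask of instance 1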


@ -52,7 +52,8 @@ class ClassificationPredictor(BasePredictor):
return log_string
prob = result.probs
# Print results
top5i = prob.argsort(0, descending=True)[:5].tolist() # top 5 indices
n5 = min(len(self.model.names), 5)
top5i = prob.argsort(0, descending=True)[:n5].tolist() # top 5 indices
log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, "
# write


@ -27,7 +27,8 @@ class ClassificationValidator(BaseValidator):
return batch
def update_metrics(self, preds, batch):
self.pred.append(preds.argsort(1, descending=True)[:, :5])
n5 = min(len(self.model.names), 5)
self.pred.append(preds.argsort(1, descending=True)[:, :n5])
self.targets.append(batch['cls'])
def finalize_metrics(self, *args, **kwargs):