ultralytics 8.0.79
expand Docs reference section (#2053)
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Fri3dChicken <87434761+AmoghDhaliwal@users.noreply.github.com>
@@ -262,13 +262,15 @@ class RandomPerspective:
         return img, M, s

     def apply_bboxes(self, bboxes, M):
-        """apply affine to bboxes only.
+        """
+        Apply affine to bboxes only.

         Args:
-            bboxes(ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
-            M(ndarray): affine matrix.
+            bboxes (ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
+            M (ndarray): affine matrix.

         Returns:
-            new_bboxes(ndarray): bboxes after affine, [num_bboxes, 4].
+            new_bboxes (ndarray): bboxes after affine, [num_bboxes, 4].
         """
         n = len(bboxes)
         if n == 0:
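For orientation, the transform this docstring describes can be sketched in a few lines of NumPy: lift the four corners of each xyxy box into homogeneous coordinates, apply M, rescale by the perspective component, and re-take the min/max. The helper name below is hypothetical; this is a standalone sketch, not the library method, though it mirrors the concatenate/min/max line visible in the next hunk.

import numpy as np

# Hypothetical standalone sketch of apply_bboxes (not the library code itself).
def affine_boxes(bboxes, M):
    n = len(bboxes)
    xy = np.ones((n * 4, 3))
    xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
    xy = xy @ M.T  # transform corners
    xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # perspective rescale
    x, y = xy[:, [0, 2, 4, 6]], xy[:, [1, 3, 5, 7]]
    return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

print(affine_boxes(np.array([[0., 0., 10., 10.]]), np.eye(3)))  # identity matrix -> box unchanged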
@@ -285,14 +287,16 @@ class RandomPerspective:
         return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T

     def apply_segments(self, segments, M):
-        """apply affine to segments and generate new bboxes from segments.
+        """
+        Apply affine to segments and generate new bboxes from segments.

         Args:
-            segments(ndarray): list of segments, [num_samples, 500, 2].
-            M(ndarray): affine matrix.
+            segments (ndarray): list of segments, [num_samples, 500, 2].
+            M (ndarray): affine matrix.

         Returns:
-            new_segments(ndarray): list of segments after affine, [num_samples, 500, 2].
-            new_bboxes(ndarray): bboxes after affine, [N, 4].
+            new_segments (ndarray): list of segments after affine, [num_samples, 500, 2].
+            new_bboxes (ndarray): bboxes after affine, [N, 4].
         """
         n, num = segments.shape[:2]
         if n == 0:
@@ -308,13 +312,15 @@ class RandomPerspective:
         return bboxes, segments

     def apply_keypoints(self, keypoints, M):
-        """apply affine to keypoints.
+        """
+        Apply affine to keypoints.

         Args:
-            keypoints(ndarray): keypoints, [N, 17, 3].
-            M(ndarray): affine matrix.
+            keypoints (ndarray): keypoints, [N, 17, 3].
+            M (ndarray): affine matrix.

         Return:
-            new_keypoints(ndarray): keypoints after affine, [N, 17, 3].
+            new_keypoints (ndarray): keypoints after affine, [N, 17, 3].
         """
         n, nkpt = keypoints.shape[:2]
         if n == 0:
@@ -333,7 +339,7 @@ class RandomPerspective:
         Affine images and targets.

         Args:
-            labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
+            labels (dict): a dict of `bboxes`, `segments`, `keypoints`.
         """
         if self.pre_transform and 'mosaic_border' not in labels:
             labels = self.pre_transform(labels)
@@ -18,11 +18,30 @@ from .utils import HELP_URL, IMG_FORMATS


 class BaseDataset(Dataset):
-    """Base Dataset.
+    """
+    Base dataset class for loading and processing image data.

     Args:
-        img_path (str): image path.
-        pipeline (dict): a dict of image transforms.
-        label_path (str): label path, this can also be an ann_file or other custom label path.
+        img_path (str): Image path.
+        imgsz (int): Target image size for resizing. Default is 640.
+        cache (bool): Cache images in memory or on disk for faster loading. Default is False.
+        augment (bool): Apply data augmentation. Default is True.
+        hyp (dict): Dictionary of hyperparameters for data augmentation. Default is None.
+        prefix (str): Prefix for file paths. Default is an empty string.
+        rect (bool): Enable rectangular training. Default is False.
+        batch_size (int): Batch size for rectangular training. Default is None.
+        stride (int): Stride for rectangular training. Default is 32.
+        pad (float): Padding for rectangular training. Default is 0.5.
+        single_cls (bool): Use a single class for all labels. Default is False.
+        classes (list): List of included classes. Default is None.
+
+    Attributes:
+        im_files (list): List of image file paths.
+        labels (list): List of label data dictionaries.
+        ni (int): Number of images in the dataset.
+        ims (list): List of loaded images.
+        npy_files (list): List of numpy file paths.
+        transforms (callable): Image transformation function.
     """

     def __init__(self,
@@ -21,10 +21,7 @@ from .utils import PIN_MEMORY


 class InfiniteDataLoader(dataloader.DataLoader):
-    """Dataloader that reuses workers
-
-    Uses same syntax as vanilla DataLoader
-    """
+    """Dataloader that reuses workers. Uses same syntax as vanilla DataLoader."""

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -40,10 +37,11 @@ class InfiniteDataLoader(dataloader.DataLoader):


 class _RepeatSampler:
-    """Sampler that repeats forever
+    """
+    Sampler that repeats forever.

     Args:
-        sampler (Sampler)
+        sampler (Dataset.sampler): The sampler to repeat.
     """

     def __init__(self, sampler):
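The worker-reuse trick both classes above document is simple enough to sketch from scratch: wrap any sampler in an iterator that never terminates, so the DataLoader's worker processes survive across epochs instead of being torn down and respawned. A minimal standalone version, assuming nothing beyond the standard library:

class RepeatSampler:
    # Minimal sketch of the repeat-forever pattern (not the library class).
    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:  # never raise StopIteration, so workers are never recycled
            yield from iter(self.sampler)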
@@ -173,7 +171,7 @@ def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1,
         auto (bool, optional): Automatically apply pre-processing. Default is True.

     Returns:
-        dataset: A dataset object for the specified input source.
+        dataset (Dataset): A dataset object for the specified input source.
     """
     source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
     source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
@@ -178,7 +178,7 @@ class _RepeatSampler:
     """ Sampler that repeats forever

     Args:
-        sampler (Sampler)
+        sampler (Dataset.sampler): The sampler to repeat.
     """

     def __init__(self, sampler):
@@ -17,31 +17,31 @@ from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_lab


 class YOLODataset(BaseDataset):
-    cache_version = '1.0.2'  # dataset labels *.cache version, >= 1.0.0 for YOLOv8
-    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
     """
-    Dataset class for loading images object detection and/or segmentation labels in YOLO format.
+    Dataset class for loading object detection and/or segmentation labels in YOLO format.

     Args:
-        img_path (str): path to the folder containing images.
-        imgsz (int): image size (default: 640).
-        cache (bool): if True, a cache file of the labels is created to speed up future creation of dataset instances
-            (default: False).
-        augment (bool): if True, data augmentation is applied (default: True).
-        hyp (dict): hyperparameters to apply data augmentation (default: None).
-        prefix (str): prefix to print in log messages (default: '').
-        rect (bool): if True, rectangular training is used (default: False).
-        batch_size (int): size of batches (default: None).
-        stride (int): stride (default: 32).
-        pad (float): padding (default: 0.0).
-        single_cls (bool): if True, single class training is used (default: False).
-        use_segments (bool): if True, segmentation masks are used as labels (default: False).
-        use_keypoints (bool): if True, keypoints are used as labels (default: False).
-        names (dict): A dictionary of class names. (default: None).
+        img_path (str): Path to the folder containing images.
+        imgsz (int, optional): Image size. Defaults to 640.
+        cache (bool, optional): Cache images to RAM or disk during training. Defaults to False.
+        augment (bool, optional): If True, data augmentation is applied. Defaults to True.
+        hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None.
+        prefix (str, optional): Prefix to print in log messages. Defaults to ''.
+        rect (bool, optional): If True, rectangular training is used. Defaults to False.
+        batch_size (int, optional): Size of batches. Defaults to None.
+        stride (int, optional): Stride. Defaults to 32.
+        pad (float, optional): Padding. Defaults to 0.0.
+        single_cls (bool, optional): If True, single class training is used. Defaults to False.
+        use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
+        use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.
+        data (dict, optional): A dataset YAML dictionary. Defaults to None.
+        classes (list): List of included classes. Default is None.

     Returns:
-        A PyTorch dataset object that can be used for training an object detection or segmentation model.
+        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
     """
+    cache_version = '1.0.2'  # dataset labels *.cache version, >= 1.0.0 for YOLOv8
+    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]

     def __init__(self,
                  img_path,
@@ -7,21 +7,36 @@ from .augment import LetterBox


 class MixAndRectDataset:
-    """A wrapper of multiple images mixed dataset.
+    """
+    A dataset class that applies mosaic and mixup transformations as well as rectangular training.

-    Args:
-        dataset (:obj:`BaseDataset`): The dataset to be mixed.
-        transforms (Sequence[dict]): config dict to be composed.
+    Attributes:
+        dataset: The base dataset.
+        imgsz: The size of the images in the dataset.
     """

     def __init__(self, dataset):
+        """
+        Args:
+            dataset (BaseDataset): The base dataset to apply transformations to.
+        """
         self.dataset = dataset
         self.imgsz = dataset.imgsz

     def __len__(self):
+        """Returns the number of items in the dataset."""
         return len(self.dataset)

     def __getitem__(self, index):
+        """
+        Applies mosaic, mixup and rectangular training transformations to an item in the dataset.
+
+        Args:
+            index (int): Index of the item in the dataset.
+
+        Returns:
+            (dict): A dictionary containing the transformed item data.
+        """
         labels = deepcopy(self.dataset[index])
         for transform in self.dataset.transforms.tolist():
             # mosaic and mixup
@@ -270,9 +270,8 @@ def check_cls_dataset(dataset: str):
     """
     Check a classification dataset such as Imagenet.

-    Copy code
     This function takes a `dataset` name as input and returns a dictionary containing information about the dataset.
-    If the dataset is not found, it attempts to download the dataset from the internet and save it to the local file system.
+    If the dataset is not found, it attempts to download the dataset from the internet and save it locally.

     Args:
         dataset (str): Name of the dataset.
@@ -306,7 +305,8 @@ def check_cls_dataset(dataset: str):


 class HUBDatasetStats():
-    """ Class for generating HUB dataset JSON and `-hub` dataset directory
+    """
+    Class for generating HUB dataset JSON and `-hub` dataset directory

     Arguments
         path:           Path to data.yaml or data.zip (with data.yaml inside data.zip)
@@ -427,9 +427,6 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
         max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
         quality (int, optional): The image compression quality as a percentage. Default is 50%.

-    Returns:
-        None
-
     Usage:
         from pathlib import Path
         from ultralytics.yolo.data.utils import compress_one_image
@@ -459,9 +456,6 @@ def delete_dsstore(path):
     Args:
         path (str, optional): The directory path where the ".DS_store" files should be deleted.

-    Returns:
-        None
-
     Usage:
         from ultralytics.yolo.data.utils import delete_dsstore
         delete_dsstore('/Users/glennjocher/Downloads/dataset')
@@ -478,15 +472,13 @@ def delete_dsstore(path):


 def zip_directory(dir, use_zipfile_library=True):
-    """Zips a directory and saves the archive to the specified output path.
+    """
+    Zips a directory and saves the archive to the specified output path.

     Args:
         dir (str): The path to the directory to be zipped.
         use_zipfile_library (bool): Whether to use zipfile library or shutil for zipping.

-    Returns:
-        None
-
     Usage:
         from ultralytics.yolo.data.utils import zip_directory
         zip_directory('/Users/glennjocher/Downloads/playground')
@@ -73,7 +73,7 @@ ARM64 = platform.machine() in ('arm64', 'aarch64')


 def export_formats():
-    # YOLOv8 export formats
+    """YOLOv8 export formats"""
     import pandas
     x = [
         ['PyTorch', '-', '.pt', True, True],
@@ -92,7 +92,7 @@ def export_formats():


 def gd_outputs(gd):
-    # TensorFlow GraphDef model output node names
+    """TensorFlow GraphDef model output node names"""
     name_list, input_list = [], []
     for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
         name_list.append(node.name)
@@ -101,7 +101,7 @@ def gd_outputs(gd):


 def try_export(inner_func):
-    # YOLOv8 export decorator, i..e @try_export
+    """YOLOv8 export decorator, i..e @try_export"""
     inner_args = get_default_args(inner_func)

     def outer_func(*args, **kwargs):
@@ -118,10 +118,26 @@ def try_export(inner_func):
     return outer_func


+class iOSDetectModel(torch.nn.Module):
+    """Wrap an Ultralytics YOLO model for iOS export"""
+
+    def __init__(self, model, im):
+        super().__init__()
+        b, c, h, w = im.shape  # batch, channel, height, width
+        self.model = model
+        self.nc = len(model.names)  # number of classes
+        if w == h:
+            self.normalize = 1.0 / w  # scalar
+        else:
+            self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h])  # broadcast (slower, smaller)
+
+    def forward(self, x):
+        xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
+        return cls, xywh * self.normalize  # confidence (3780, 80), coordinates (3780, 4)
+
+
 class Exporter:
     """
     Exporter

     A class for exporting a model.

     Attributes:
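The normalize attribute in the class promoted above exists because CoreML consumers expect box coordinates in [0, 1]: for square inputs one scalar divides all four xywh values, while non-square inputs need a per-coordinate 4-vector that broadcasts across boxes. A toy illustration with assumed shapes:

import torch

xywh = torch.rand(3780, 4) * 640              # pixel-space boxes for a 640x640 input
print((xywh * (1.0 / 640)).max() <= 1.0)      # scalar path: all coords now in [0, 1]

h, w = 480, 640                               # non-square input needs per-axis scales
normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h])
print((xywh * normalize).shape)               # torch.Size([3780, 4]) via broadcasting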
@@ -136,6 +152,7 @@ class Exporter:
         Args:
             cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
             overrides (dict, optional): Configuration overrides. Defaults to None.
+            _callbacks (list, optional): List of callback functions. Defaults to None.
         """
         self.args = get_cfg(cfg, overrides)
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
@@ -385,22 +402,6 @@ class Exporter:
         check_requirements('coremltools>=6.0')
         import coremltools as ct  # noqa

-        class iOSDetectModel(torch.nn.Module):
-            # Wrap an Ultralytics YOLO model for iOS export
-            def __init__(self, model, im):
-                super().__init__()
-                b, c, h, w = im.shape  # batch, channel, height, width
-                self.model = model
-                self.nc = len(model.names)  # number of classes
-                if w == h:
-                    self.normalize = 1.0 / w  # scalar
-                else:
-                    self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h])  # broadcast (slower, smaller)
-
-            def forward(self, x):
-                xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
-                return cls, xywh * self.normalize  # confidence (3780, 80), coordinates (3780, 4)
-
         LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
         f = self.file.with_suffix('.mlmodel')
@@ -400,7 +400,7 @@ class YOLO:
             train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.

         Returns:
-            A dictionary containing the results of the hyperparameter search.
+            (dict): A dictionary containing the results of the hyperparameter search.

         Raises:
             ModuleNotFoundError: If Ray Tune is not installed.
@@ -127,7 +127,10 @@ class BasePredictor:
             log_string += result.verbose()

         if self.args.save or self.args.show:  # Add bbox to image
-            plot_args = dict(line_width=self.args.line_thickness, boxes=self.args.boxes)
+            plot_args = dict(line_width=self.args.line_thickness,
+                             boxes=self.args.boxes,
+                             conf=self.args.show_conf,
+                             labels=self.args.show_labels)
             if not self.args.retina_masks:
                 plot_args['im_gpu'] = im[idx]
             self.plotted_img = result.plot(**plot_args)
@@ -621,7 +621,7 @@ def check_amp(model):
         model (nn.Module): A YOLOv8 model instance.

     Returns:
-        bool: Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
+        (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.

     Raises:
         AssertionError: If the AMP checks fail, indicating anomalies with the AMP functionality on the system.
@@ -22,7 +22,7 @@ def check_train_batch_size(model, imgsz=640, amp=True):
         amp (bool): If True, use automatic mixed precision (AMP) for training.

     Returns:
-        int: Optimal batch size computed using the autobatch() function.
+        (int): Optimal batch size computed using the autobatch() function.
     """

     with torch.cuda.amp.autocast(amp):
@@ -34,13 +34,13 @@ def autobatch(model, imgsz=640, fraction=0.67, batch_size=16):
     Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.

     Args:
-        model: YOLO model to compute batch size for.
+        model (torch.nn.module): YOLO model to compute batch size for.
         imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
         fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67.
         batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.

     Returns:
         (int): The optimal batch size.
     """

     # Check device
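The estimation strategy the docstring summarizes can be sketched as: profile memory at a few small batch sizes, fit a line, and solve for the batch size that fills the requested fraction of the card. The numbers below are toy values and the linear fit is an assumed simplification of the real routine:

import numpy as np

batch_sizes = [1, 2, 4, 8, 16]
mem_used = [0.5, 0.9, 1.7, 3.3, 6.5]        # GB measured per probe batch (toy numbers)
total, fraction = 16.0, 0.67                # 16 GB card, target 67% utilization
slope, intercept = np.polyfit(batch_sizes, mem_used, deg=1)  # mem ~ slope*batch + intercept
print(int((total * fraction - intercept) / slope))           # -> 26, the estimated optimal batch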
@@ -19,14 +19,14 @@ except (ImportError, AssertionError):
     clearml = None


-def _log_debug_samples(files, title='Debug Samples'):
+def _log_debug_samples(files, title='Debug Samples') -> None:
     """
-        Log files (images) as debug samples in the ClearML task.
+    Log files (images) as debug samples in the ClearML task.

-        arguments:
-        files (List(PosixPath)) a list of file paths in PosixPath format
-        title (str) A title that groups together images with the same values
-        """
+    Args:
+        files (list): A list of file paths in PosixPath format.
+        title (str): A title that groups together images with the same values.
+    """
     task = Task.current_task()
     if task:
         for f in files:
@@ -39,20 +39,23 @@ def _log_debug_samples(files, title='Debug Samples'):
                                             iteration=iteration)


-def _log_plot(title, plot_path):
+def _log_plot(title, plot_path) -> None:
     """
-        Log image as plot in the plot section of ClearML
+    Log an image as a plot in the plot section of ClearML.

-        arguments:
-        title (str) Title of the plot
-        plot_path (PosixPath or str) Path to the saved image file
-        """
+    Args:
+        title (str): The title of the plot.
+        plot_path (str): The path to the saved image file.
+    """
     img = mpimg.imread(plot_path)
     fig = plt.figure()
     ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks
     ax.imshow(img)

-    Task.current_task().get_logger().report_matplotlib_figure(title, '', figure=fig, report_interactive=False)
+    Task.current_task().get_logger().report_matplotlib_figure(title=title,
+                                                              series='',
+                                                              figure=fig,
+                                                              report_interactive=False)


 def on_pretrain_routine_start(trainer):
@@ -47,13 +47,13 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
     stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.

     Args:
-        imgsz (int or List[int]): Image size.
+        imgsz (int) or (List[int]): Image size.
         stride (int): Stride value.
         min_dim (int): Minimum number of dimensions.
         floor (int): Minimum allowed value for image size.

     Returns:
-        List[int]: Updated image size.
+        (List[int]): Updated image size.
     """
     # Convert stride to integer if it is a tensor
     stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
@@ -106,7 +106,7 @@ def check_version(current: str = '0.0.0',
         verbose (bool): If True, print warning message if minimum version is not met.

     Returns:
-        bool: True if minimum version is met, False otherwise.
+        (bool): True if minimum version is met, False otherwise.
     """
     current, minimum = (pkg.parse_version(x) for x in (current, minimum))
     result = (current == minimum) if pinned else (current >= minimum)  # bool
@@ -126,7 +126,7 @@ def check_latest_pypi_version(package_name='ultralytics'):
         package_name (str): The name of the package to find the latest version for.

     Returns:
-        str: The latest version of the package.
+        (str): The latest version of the package.
     """
     requests.packages.urllib3.disable_warnings()  # Disable the InsecureRequestWarning
     response = requests.get(f'https://pypi.org/pypi/{package_name}/json', verify=False)
@@ -140,7 +140,7 @@ def check_pip_update_available():
     Checks if a new version of the ultralytics package is available on PyPI.

     Returns:
-        bool: True if an update is available, False otherwise.
+        (bool): True if an update is available, False otherwise.
     """
     if ONLINE and is_pip_package():
         with contextlib.suppress(Exception):
@@ -206,9 +206,6 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
         exclude (Tuple[str]): Tuple of package names to exclude from checking.
         install (bool): If True, attempt to auto-update packages that don't meet requirements.
         cmds (str): Additional commands to pass to the pip install command when auto-updating.
-
-    Returns:
-        None
     """
     prefix = colorstr('red', 'bold', 'requirements:')
     check_python()  # check python version
@@ -67,21 +67,21 @@ def safe_download(url,
                   min_bytes=1E0,
                   progress=True):
     """
-    Function for downloading files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
+    Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.

     Args:
-        url: str: The URL of the file to be downloaded.
-        file: str, optional: The filename of the downloaded file.
+        url (str): The URL of the file to be downloaded.
+        file (str, optional): The filename of the downloaded file.
             If not provided, the file will be saved with the same name as the URL.
-        dir: str, optional: The directory to save the downloaded file.
+        dir (str, optional): The directory to save the downloaded file.
             If not provided, the file will be saved in the current working directory.
-        unzip: bool, optional: Whether to unzip the downloaded file. Default: True.
-        delete: bool, optional: Whether to delete the downloaded file after unzipping. Default: False.
-        curl: bool, optional: Whether to use curl command line tool for downloading. Default: False.
-        retry: int, optional: The number of times to retry the download in case of failure. Default: 3.
-        min_bytes: float, optional: The minimum number of bytes that the downloaded file should have, to be considered
+        unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
+        delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
+        curl (bool, optional): Whether to use curl command line tool for downloading. Default: False.
+        retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
+        min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
             a successful download. Default: 1E0.
-        progress: bool, optional: Whether to display a progress bar during the download. Default: True.
+        progress (bool, optional): Whether to display a progress bar during the download. Default: True.
     """
     if '://' not in str(url) and Path(url).is_file():  # exists ('://' check required in Windows Python<3.10)
         f = Path(url)  # filename
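A usage sketch built only from the signature and Args documented above (the module path is assumed from this file's location under ultralytics/yolo/utils):

from ultralytics.yolo.utils.downloads import safe_download

# Download a zip, extract it into 'datasets', keep the archive, retry up to 3 times.
safe_download(url='https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip',
              dir='datasets', unzip=True, delete=False, retry=3)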
@@ -30,13 +30,13 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
     directory if it does not already exist.

     Args:
-        path (str or pathlib.Path): Path to increment.
-        exist_ok (bool, optional): If True, the path will not be incremented and will be returned as-is. Defaults to False.
-        sep (str, optional): Separator to use between the path and the incrementation number. Defaults to an empty string.
-        mkdir (bool, optional): If True, the path will be created as a directory if it does not exist. Defaults to False.
+        path (str, pathlib.Path): Path to increment.
+        exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False.
+        sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''.
+        mkdir (bool, optional): Create a directory if it does not exist. Defaults to False.

     Returns:
-        pathlib.Path: Incremented path.
+        (pathlib.Path): Incremented path.
     """
     path = Path(path)  # os-agnostic
     if path.exists() and not exist_ok:
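A behavior sketch per the docstring above (module path assumed): if runs/exp already exists, successive calls yield runs/exp2, runs/exp3, and so on, optionally creating the directory.

from ultralytics.yolo.utils.files import increment_path

save_dir = increment_path('runs/exp', exist_ok=False, mkdir=True)
print(save_dir)  # runs/exp on the first call, runs/exp2 once runs/exp exists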
@@ -98,7 +98,7 @@ class Bboxes:
     def mul(self, scale):
         """
         Args:
-            scale (tuple | List | int): the scale for four coords.
+            scale (tuple) or (list) or (int): the scale for four coords.
         """
         if isinstance(scale, Number):
             scale = to_4tuple(scale)
@@ -112,7 +112,7 @@ class Bboxes:
     def add(self, offset):
         """
         Args:
-            offset (tuple | List | int): the offset for four coords.
+            offset (tuple) or (list) or (int): the offset for four coords.
         """
         if isinstance(offset, Number):
             offset = to_4tuple(offset)
@@ -129,13 +129,18 @@ class Bboxes:
     @classmethod
     def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
         """
-        Concatenates a list of Boxes into a single Bboxes
+        Concatenate a list of Bboxes objects into a single Bboxes object.

-        Arguments:
-            boxes_list (list[Bboxes])
+        Args:
+            boxes_list (List[Bboxes]): A list of Bboxes objects to concatenate.
+            axis (int, optional): The axis along which to concatenate the bounding boxes.
+                                  Defaults to 0.

         Returns:
-            Bboxes: the concatenated Boxes
+            Bboxes: A new Bboxes object containing the concatenated bounding boxes.
+
+        Note:
+            The input should be a list or tuple of Bboxes objects.
         """
         assert isinstance(boxes_list, (list, tuple))
         if not boxes_list:
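A hypothetical usage of the classmethod documented above (the Bboxes constructor is assumed here to take (bboxes, format)):

import numpy as np
from ultralytics.yolo.utils.instance import Bboxes

a = Bboxes(np.array([[0, 0, 10, 10]], dtype=np.float32), format='xyxy')
b = Bboxes(np.array([[5, 5, 20, 20]], dtype=np.float32), format='xyxy')
merged = Bboxes.concatenate([a, b], axis=0)
print(merged.bboxes.shape)  # (2, 4)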
@@ -148,11 +153,21 @@ class Bboxes:

     def __getitem__(self, index) -> 'Bboxes':
         """
+        Retrieve a specific bounding box or a set of bounding boxes using indexing.
+
         Args:
-            index: int, slice, or a BoolArray
+            index (int, slice, or np.ndarray): The index, slice, or boolean array to select
+                                               the desired bounding boxes.

         Returns:
-            Bboxes: Create a new :class:`Bboxes` by indexing.
+            Bboxes: A new Bboxes object containing the selected bounding boxes.
+
+        Raises:
+            AssertionError: If the indexed bounding boxes do not form a 2-dimensional matrix.
+
+        Note:
+            When using boolean indexing, make sure to provide a boolean array with the same
+            length as the number of bounding boxes.
         """
         if isinstance(index, int):
             return Bboxes(self.bboxes[index].view(1, -1))
@@ -236,11 +251,19 @@ class Instances:

     def __getitem__(self, index) -> 'Instances':
         """
+        Retrieve a specific instance or a set of instances using indexing.
+
         Args:
-            index: int, slice, or a BoolArray
+            index (int, slice, or np.ndarray): The index, slice, or boolean array to select
+                                               the desired instances.

         Returns:
-            Instances: Create a new :class:`Instances` by indexing.
+            Instances: A new Instances object containing the selected bounding boxes,
+                       segments, and keypoints if present.
+
+        Note:
+            When using boolean indexing, make sure to provide a boolean array with the same
+            length as the number of instances.
         """
         segments = self.segments[index] if len(self.segments) else self.segments
         keypoints = self.keypoints[index] if self.keypoints is not None else None
@@ -305,14 +328,20 @@ class Instances:
     @classmethod
     def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
         """
-        Concatenates a list of Boxes into a single Bboxes
+        Concatenates a list of Instances objects into a single Instances object.

-        Arguments:
-            instances_list (list[Bboxes])
-            axis
+        Args:
+            instances_list (List[Instances]): A list of Instances objects to concatenate.
+            axis (int, optional): The axis along which the arrays will be concatenated. Defaults to 0.

         Returns:
-            Boxes: the concatenated Boxes
+            Instances: A new Instances object containing the concatenated bounding boxes,
+                       segments, and keypoints if present.
+
+        Note:
+            The `Instances` objects in the list should have the same properties, such as
+            the format of the bounding boxes, whether keypoints are present, and if the
+            coordinates are normalized.
         """
         assert isinstance(instances_list, (list, tuple))
         if not instances_list:
@@ -23,10 +23,16 @@ def box_area(box):


 def bbox_ioa(box1, box2, eps=1e-7):
-    """Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
-    box1: np.array of shape(nx4)
-    box2: np.array of shape(mx4)
-    returns: np.array of shape(nxm)
+    """
+    Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
+
+    Args:
+        box1 (np.array): A numpy array of shape (n, 4) representing n bounding boxes.
+        box2 (np.array): A numpy array of shape (m, 4) representing m bounding boxes.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        (np.array): A numpy array of shape (n, m) representing the intersection over box2 area.
     """

     # Get the coordinates of bounding boxes
@@ -46,17 +52,17 @@ def bbox_ioa(box1, box2, eps=1e-7):

 def box_iou(box1, box2, eps=1e-7):
     """
-    Return intersection-over-union (Jaccard index) of boxes.
+    Calculate intersection-over-union (IoU) of boxes.
     Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
     Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py

-    Arguments:
-        box1 (Tensor[N, 4])
-        box2 (Tensor[M, 4])
-        eps
+    Args:
+        box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
+        box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

     Returns:
-        iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
+        (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
     """

     # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
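A worked pairwise-IoU example matching the shapes in the docstring above; this is a from-scratch sketch, not the library body:

import torch

box1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])          # (N=1, 4)
box2 = torch.tensor([[5.0, 5.0, 15.0, 15.0],
                     [20.0, 20.0, 30.0, 30.0]])        # (M=2, 4)
lt = torch.max(box1[:, None, :2], box2[None, :, :2])   # intersection top-left corners
rb = torch.min(box1[:, None, 2:], box2[None, :, 2:])   # intersection bottom-right corners
inter = (rb - lt).clamp(0).prod(2)                     # (N, M) intersection areas
area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
print(inter / (area1[:, None] + area2[None, :] - inter + 1e-7))
# tensor([[0.1429, 0.0000]]): 25 / (100 + 100 - 25) overlap, then no overlap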
@@ -68,7 +74,22 @@ def box_iou(box1, box2, eps=1e-7):


 def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
-    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
+    """
+    Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).
+
+    Args:
+        box1 (torch.Tensor): A tensor representing a single bounding box with shape (1, 4).
+        box2 (torch.Tensor): A tensor representing n bounding boxes with shape (n, 4).
+        xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
+                               (x1, y1, x2, y2) format. Defaults to True.
+        GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
+        DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
+        CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
+    """

     # Get the coordinates of bounding boxes
     if xywh:  # transform from xywh to xyxy
@@ -110,10 +131,17 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7

 def mask_iou(mask1, mask2, eps=1e-7):
     """
-    mask1: [N, n] m1 means number of gt objects
-    mask2: [M, n] m2 means number of predicted objects
-    Note: n means image_w x image_h
-    Returns: masks iou, [N, M]
+    Calculate masks IoU.
+
+    Args:
+        mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
+                              product of image width and height.
+        mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
+                              product of image width and height.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        (torch.Tensor): A tensor of shape (N, M) representing masks IoU.
     """
     intersection = torch.matmul(mask1, mask2.t()).clamp(0)
     union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection  # (area1 + area2) - intersection
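The matmul in the body is the whole trick: for flattened binary masks, the dot product of two rows counts the pixels they share. A tiny worked example:

import torch

mask1 = torch.tensor([[1., 1., 0., 0.]])        # N=1 ground-truth mask, n=4 pixels
mask2 = torch.tensor([[1., 0., 0., 0.],
                      [0., 0., 1., 1.]])        # M=2 predicted masks
intersection = torch.matmul(mask1, mask2.t())   # (1, 2) shared-pixel counts
union = mask1.sum(1)[:, None] + mask2.sum(1)[None] - intersection
print(intersection / (union + 1e-7))            # tensor([[0.5000, 0.0000]])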
@@ -121,10 +149,18 @@ def mask_iou(mask1, mask2, eps=1e-7):


 def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
-    """OKS
-    kpt1: [N, 17, 3], gt
-    kpt2: [M, 17, 3], pred
-    area: [N], areas from gt
+    """
+    Calculate Object Keypoint Similarity (OKS).
+
+    Args:
+        kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
+        kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.
+        area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.
+        sigma (list): A list containing 17 values representing keypoint scales.
+        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
+
+    Returns:
+        (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.
     """
     d = (kpt1[:, None, :, 0] - kpt2[..., 0]) ** 2 + (kpt1[:, None, :, 1] - kpt2[..., 1]) ** 2  # (N, M, 17)
     sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype)  # (17, )
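For reference, a sketch of the standard COCO OKS formula with the shapes documented above: exp(-d_i / (2 * s^2 * k_i^2)) averaged over keypoints, where d_i is the squared distance, s^2 the object area, and k_i = 2 * sigma_i. This is the published form of OKS, assumed here rather than quoted from the function body:

import torch

kpt1 = torch.rand(1, 17, 3)      # N=1 ground-truth pose, rows of (x, y, visibility)
kpt2 = torch.rand(2, 17, 3)      # M=2 predicted poses
area = torch.tensor([100.0])     # (N,) ground-truth areas
sigma = torch.full((17,), 0.05)  # per-keypoint scales (toy values)
d = (kpt1[:, None, :, 0] - kpt2[..., 0]) ** 2 + (kpt1[:, None, :, 1] - kpt2[..., 1]) ** 2  # (N, M, 17)
e = d / ((2 * sigma) ** 2 * area[:, None, None] * 2 + 1e-7)
print(torch.exp(-e).mean(-1))    # (N, M) keypoint similarities in [0, 1]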
@@ -171,7 +207,17 @@ class FocalLoss(nn.Module):


 class ConfusionMatrix:
-    # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
+    """
+    A class for calculating and updating a confusion matrix for object detection and classification tasks.
+
+    Attributes:
+        task (str): The type of task, either 'detect' or 'classify'.
+        matrix (np.array): The confusion matrix, with dimensions depending on the task.
+        nc (int): The number of classes.
+        conf (float): The confidence threshold for detections.
+        iou_thres (float): The Intersection over Union threshold.
+    """
+
     def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'):
         self.task = task
         self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc))
@@ -183,12 +229,9 @@ class ConfusionMatrix:
         """
         Update confusion matrix for classification task

-        Arguments:
-            preds (Array[N, min(nc,5)])
-            targets (Array[N, 1])
-
-        Returns:
-            None, updates confusion matrix accordingly
+        Args:
+            preds (Array[N, min(nc,5)]): Predicted class labels.
+            targets (Array[N, 1]): Ground truth class labels.
         """
         preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
         for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
@@ -196,15 +239,13 @@ class ConfusionMatrix:

     def process_batch(self, detections, labels):
         """
-        Return intersection-over-union (Jaccard index) of boxes.
-        Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+        Update confusion matrix for object detection task.

-        Arguments:
-            detections (Array[N, 6]), x1, y1, x2, y2, conf, class
-            labels (Array[M, 5]), class, x1, y1, x2, y2
-
-        Returns:
-            None, updates confusion matrix accordingly
+        Args:
+            detections (Array[N, 6]): Detected bounding boxes and their associated information.
+                                      Each row should contain (x1, y1, x2, y2, conf, class).
+            labels (Array[M, 5]): Ground truth bounding boxes and their associated class labels.
+                                  Each row should contain (class, x1, y1, x2, y2).
         """
         if detections is None:
             gt_classes = labels.int()
@@ -254,6 +295,14 @@ class ConfusionMatrix:
     @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
     @plt_settings()
     def plot(self, normalize=True, save_dir='', names=()):
+        """
+        Plot the confusion matrix using seaborn and save it to a file.
+
+        Args:
+            normalize (bool): Whether to normalize the confusion matrix.
+            save_dir (str): Directory where the plot will be saved.
+            names (tuple): Names of classes, used as labels on the plot.
+        """
         import seaborn as sn

         array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1)  # normalize columns
@@ -284,6 +333,9 @@ class ConfusionMatrix:
             plt.close(fig)

     def print(self):
+        """
+        Print the confusion matrix to the console.
+        """
         for i in range(self.nc + 1):
             LOGGER.info(' '.join(map(str, self.matrix[i])))
@@ -343,12 +395,17 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi


 def compute_ap(recall, precision):
-    """ Compute the average precision, given the recall and precision curves
+    """
+    Compute the average precision (AP) given the recall and precision curves.

-    Arguments:
-        recall: The recall curve (list)
-        precision: The precision curve (list)
+    Args:
+        recall (list): The recall curve.
+        precision (list): The precision curve.

     Returns:
-        Average precision, precision curve, recall curve
+        (float): Average precision.
+        (np.ndarray): Precision envelope curve.
+        (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
     """

     # Append sentinel values to beginning and end
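The sentinel values and precision envelope named in the Returns section can be sketched directly; the 101-point interpolation below is one common integration choice in YOLO codebases, assumed here rather than quoted from this file:

import numpy as np

def compute_ap_sketch(recall, precision):
    mrec = np.concatenate(([0.0], recall, [1.0]))         # sentinel values
    mpre = np.concatenate(([1.0], precision, [0.0]))
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))  # precision envelope
    x = np.linspace(0, 1, 101)                            # 101-point interpolation
    return np.trapz(np.interp(x, mrec, mpre), x), mpre, mrec

ap, _, _ = compute_ap_sketch([0.2, 0.6, 0.9], [1.0, 0.8, 0.5])
print(round(float(ap), 3))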
@@ -488,57 +545,71 @@ class Metric(SimpleClass):

     @property
     def ap50(self):
-        """AP@0.5 of all classes.
+        """
+        Returns the Average Precision (AP) at an IoU threshold of 0.5 for all classes.

         Returns:
-            (nc, ) or [].
+            (np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
         """
         return self.all_ap[:, 0] if len(self.all_ap) else []

     @property
     def ap(self):
-        """AP@0.5:0.95
+        """
+        Returns the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.

         Returns:
-            (nc, ) or [].
+            (np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
         """
         return self.all_ap.mean(1) if len(self.all_ap) else []

     @property
     def mp(self):
-        """mean precision of all classes.
+        """
+        Returns the Mean Precision of all classes.

         Returns:
-            float.
+            (float): The mean precision of all classes.
         """
         return self.p.mean() if len(self.p) else 0.0

     @property
     def mr(self):
-        """mean recall of all classes.
+        """
+        Returns the Mean Recall of all classes.

         Returns:
-            float.
+            (float): The mean recall of all classes.
         """
         return self.r.mean() if len(self.r) else 0.0

     @property
     def map50(self):
-        """Mean AP@0.5 of all classes.
+        """
+        Returns the mean Average Precision (mAP) at an IoU threshold of 0.5.

         Returns:
-            float.
+            (float): The mAP at an IoU threshold of 0.5.
         """
         return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0

     @property
     def map75(self):
-        """Mean AP@0.75 of all classes.
+        """
+        Returns the mean Average Precision (mAP) at an IoU threshold of 0.75.

         Returns:
-            float.
+            (float): The mAP at an IoU threshold of 0.75.
         """
         return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0

     @property
     def map(self):
-        """Mean AP@0.5:0.95 of all classes.
+        """
+        Returns the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.

         Returns:
-            float.
+            (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
         """
         return self.all_ap.mean() if len(self.all_ap) else 0.0
@@ -566,7 +637,7 @@ class Metric(SimpleClass):
     def update(self, results):
         """
         Args:
-            results: tuple(p, r, ap, f1, ap_class)
+            results (tuple): A tuple of (p, r, ap, f1, ap_class)
         """
         self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results
@@ -120,10 +120,10 @@ def make_divisible(x, divisor):

     Args:
         x (int): The number to make divisible.
-        divisor (int or torch.Tensor): The divisor.
+        divisor (int) or (torch.Tensor): The divisor.

     Returns:
-        int: The nearest number divisible by the divisor.
+        (int): The nearest number divisible by the divisor.
     """
     if isinstance(divisor, torch.Tensor):
         divisor = int(divisor.max())  # to int
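A one-line sketch of the documented behavior (rounding up is assumed, as is usual for channel-count helpers):

import math

def make_divisible_sketch(x, divisor):
    return math.ceil(x / divisor) * divisor  # round x up to the nearest multiple

print(make_divisible_sketch(100, 32))  # 128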
@@ -127,7 +127,7 @@ class Annotator:
         masks_color = masks * (colors * alpha)  # shape(n,h,w,3)

         inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
-        mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand shape(n,h,w,3)
+        mcs = masks_color.max(dim=0).values  # shape(n,h,w,3)

         im_gpu = im_gpu.flip(dims=[0])  # flip channel
         im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
@@ -140,12 +140,16 @@ class Annotator:
         self.fromarray(self.im)

     def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True):
-        """Plot keypoints.
+        """Plot keypoints on the image.

         Args:
-            kpts (tensor): predicted kpts, shape: [17, 3]
-            shape (tuple): image shape, (h, w)
-            steps (int): keypoints step
-            radius (int): size of drawing points
+            kpts (tensor): Predicted keypoints with shape [17, 3]. Each keypoint has (x, y, confidence).
+            shape (tuple): Image shape as a tuple (h, w), where h is the height and w is the width.
+            radius (int, optional): Radius of the drawn keypoints. Default is 5.
+            kpt_line (bool, optional): If True, the function will draw lines connecting keypoints
+                                       for human pose. Default is True.
+
+        Note: `kpt_line=True` currently only supports human pose plotting.
         """
         if self.pil:
             # convert to numpy first
@@ -54,6 +54,19 @@ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):


 class TaskAlignedAssigner(nn.Module):
+    """
+    A task-aligned assigner for object detection.
+
+    This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric,
+    which combines both classification and localization information.
+
+    Attributes:
+        topk (int): The number of top candidates to consider.
+        num_classes (int): The number of object classes.
+        alpha (float): The alpha parameter for the classification component of the task-aligned metric.
+        beta (float): The beta parameter for the localization component of the task-aligned metric.
+        eps (float): A small value to prevent division by zero.
+    """

     def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
         super().__init__()
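The task-aligned metric this docstring names is usually published as score**alpha * iou**beta; a toy computation shows why the beta=6 default lets localization dominate. The formula follows the TOOD/TAL papers, assumed here rather than quoted from this file:

import torch

alpha, beta = 1.0, 6.0
score = torch.tensor([0.9, 0.6])     # predicted class scores for two candidate anchors
iou = torch.tensor([0.5, 0.8])       # IoU of their boxes with the ground-truth box
print(score ** alpha * iou ** beta)  # tensor([0.0141, 0.1573]): the better-localized anchor wins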
@@ -66,8 +79,9 @@ class TaskAlignedAssigner(nn.Module):

     @torch.no_grad()
     def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
-        """This code referenced to
-        https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
+        """
+        Compute the task-aligned assignment.
+        Reference https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py

         Args:
             pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
@@ -76,11 +90,13 @@ class TaskAlignedAssigner(nn.Module):
             gt_labels (Tensor): shape(bs, n_max_boxes, 1)
             gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
             mask_gt (Tensor): shape(bs, n_max_boxes, 1)
+
         Returns:
             target_labels (Tensor): shape(bs, num_total_anchors)
             target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
             target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
             fg_mask (Tensor): shape(bs, num_total_anchors)
+            target_gt_idx (Tensor): shape(bs, num_total_anchors)
         """
         self.bs = pd_scores.size(0)
         self.n_max_boxes = gt_bboxes.size(1)
@@ -142,9 +158,19 @@ class TaskAlignedAssigner(nn.Module):

     def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
         """
+        Select the top-k candidates based on the given metrics.
+
         Args:
-            metrics: (b, max_num_obj, h*w).
-            topk_mask: (b, max_num_obj, topk) or None
+            metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
+                              max_num_obj is the maximum number of objects, and h*w represents the
+                              total number of anchor points.
+            largest (bool): If True, select the largest values; otherwise, select the smallest values.
+            topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
+                                topk is the number of top candidates to consider. If not provided,
+                                the top-k values are automatically computed based on the given metrics.
+
+        Returns:
+            (Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
         """

         num_anchors = metrics.shape[-1]  # h*w
|
||||
|
||||
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
|
||||
"""
|
||||
Compute target labels, target bounding boxes, and target scores for the positive anchor points.
|
||||
|
||||
Args:
|
||||
gt_labels: (b, max_num_obj, 1)
|
||||
gt_bboxes: (b, max_num_obj, 4)
|
||||
target_gt_idx: (b, h*w)
|
||||
fg_mask: (b, h*w)
|
||||
gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
|
||||
batch size and max_num_obj is the maximum number of objects.
|
||||
gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
|
||||
target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
|
||||
anchor points, with shape (b, h*w), where h*w is the total
|
||||
number of anchor points.
|
||||
fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
|
||||
(foreground) anchor points.
|
||||
|
||||
Returns:
|
||||
(Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
|
||||
- target_labels (Tensor): Shape (b, h*w), containing the target labels for
|
||||
positive anchor points.
|
||||
- target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes
|
||||
for positive anchor points.
|
||||
- target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
|
||||
for positive anchor points, where num_classes is the number
|
||||
of object classes.
|
||||
"""
|
||||
|
||||
# assigned target labels, (b, 1)
|
||||
# Assigned target labels, (b, 1)
|
||||
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
|
||||
target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)
|
||||
target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w)
|
||||
|
||||
# assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
|
||||
# Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
|
||||
target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
|
||||
|
||||
# assigned target scores
|
||||
# Assigned target scores
|
||||
target_labels.clamp(0)
|
||||
target_scores = F.one_hot(target_labels, self.num_classes) # (b, h*w, 80)
|
||||
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)
|
||||
|
@@ -427,7 +427,7 @@ class EarlyStopping:
             fitness (float): Fitness value of current epoch

         Returns:
-            bool: True if training should stop, False otherwise
+            (bool): True if training should stop, False otherwise
         """
         if fitness is None:  # check if fitness=None (happens when val=False)
             return False
@@ -101,18 +101,6 @@ class ClassificationTrainer(BaseTrainer):
         loss_items = loss.detach()
         return loss, loss_items

-    # def label_loss_items(self, loss_items=None, prefix="train"):
-    #     """
-    #     Returns a loss dict with labelled training loss items tensor
-    #     """
-    #     # Not needed for classification but necessary for segmentation & detection
-    #     keys = [f"{prefix}/{x}" for x in self.loss_names]
-    #     if loss_items is not None:
-    #         loss_items = [round(float(x), 5) for x in loss_items]  # convert tensors to 5 decimal place floats
-    #         return dict(zip(keys, loss_items))
-    #     else:
-    #         return keys
-
     def label_loss_items(self, loss_items=None, prefix='train'):
         """
         Returns a loss dict with labelled training loss items tensor