ultralytics 8.0.79 expand Docs reference section (#2053)

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Fri3dChicken <87434761+AmoghDhaliwal@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-04-16 12:28:12 +02:00
committed by GitHub
parent 47bd8b433b
commit 31db8ed163
106 changed files with 2570 additions and 529 deletions

View File

@ -22,7 +22,7 @@ def check_train_batch_size(model, imgsz=640, amp=True):
amp (bool): If True, use automatic mixed precision (AMP) for training.
Returns:
int: Optimal batch size computed using the autobatch() function.
(int): Optimal batch size computed using the autobatch() function.
"""
with torch.cuda.amp.autocast(amp):
@ -34,13 +34,13 @@ def autobatch(model, imgsz=640, fraction=0.67, batch_size=16):
Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
Args:
model: YOLO model to compute batch size for.
model (torch.nn.module): YOLO model to compute batch size for.
imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67.
batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.
Returns:
int: The optimal batch size.
(int): The optimal batch size.
"""
# Check device

View File

@ -19,14 +19,14 @@ except (ImportError, AssertionError):
clearml = None
def _log_debug_samples(files, title='Debug Samples'):
def _log_debug_samples(files, title='Debug Samples') -> None:
"""
Log files (images) as debug samples in the ClearML task.
Log files (images) as debug samples in the ClearML task.
arguments:
files (List(PosixPath)) a list of file paths in PosixPath format
title (str) A title that groups together images with the same values
"""
Args:
files (list): A list of file paths in PosixPath format.
title (str): A title that groups together images with the same values.
"""
task = Task.current_task()
if task:
for f in files:
@ -39,20 +39,23 @@ def _log_debug_samples(files, title='Debug Samples'):
iteration=iteration)
def _log_plot(title, plot_path):
def _log_plot(title, plot_path) -> None:
"""
Log image as plot in the plot section of ClearML
Log an image as a plot in the plot section of ClearML.
arguments:
title (str) Title of the plot
plot_path (PosixPath or str) Path to the saved image file
"""
Args:
title (str): The title of the plot.
plot_path (str): The path to the saved image file.
"""
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
ax.imshow(img)
Task.current_task().get_logger().report_matplotlib_figure(title, '', figure=fig, report_interactive=False)
Task.current_task().get_logger().report_matplotlib_figure(title=title,
series='',
figure=fig,
report_interactive=False)
def on_pretrain_routine_start(trainer):

View File

@ -47,13 +47,13 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
Args:
imgsz (int or List[int]): Image size.
imgsz (int) or (cList[int]): Image size.
stride (int): Stride value.
min_dim (int): Minimum number of dimensions.
floor (int): Minimum allowed value for image size.
Returns:
List[int]: Updated image size.
(List[int]): Updated image size.
"""
# Convert stride to integer if it is a tensor
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
@ -106,7 +106,7 @@ def check_version(current: str = '0.0.0',
verbose (bool): If True, print warning message if minimum version is not met.
Returns:
bool: True if minimum version is met, False otherwise.
(bool): True if minimum version is met, False otherwise.
"""
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
result = (current == minimum) if pinned else (current >= minimum) # bool
@ -126,7 +126,7 @@ def check_latest_pypi_version(package_name='ultralytics'):
package_name (str): The name of the package to find the latest version for.
Returns:
str: The latest version of the package.
(str): The latest version of the package.
"""
requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
response = requests.get(f'https://pypi.org/pypi/{package_name}/json', verify=False)
@ -140,7 +140,7 @@ def check_pip_update_available():
Checks if a new version of the ultralytics package is available on PyPI.
Returns:
bool: True if an update is available, False otherwise.
(bool): True if an update is available, False otherwise.
"""
if ONLINE and is_pip_package():
with contextlib.suppress(Exception):
@ -206,9 +206,6 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
exclude (Tuple[str]): Tuple of package names to exclude from checking.
install (bool): If True, attempt to auto-update packages that don't meet requirements.
cmds (str): Additional commands to pass to the pip install command when auto-updating.
Returns:
None
"""
prefix = colorstr('red', 'bold', 'requirements:')
check_python() # check python version

View File

@ -67,21 +67,21 @@ def safe_download(url,
min_bytes=1E0,
progress=True):
"""
Function for downloading files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
Args:
url: str: The URL of the file to be downloaded.
file: str, optional: The filename of the downloaded file.
url (str): The URL of the file to be downloaded.
file (str, optional): The filename of the downloaded file.
If not provided, the file will be saved with the same name as the URL.
dir: str, optional: The directory to save the downloaded file.
dir (str, optional): The directory to save the downloaded file.
If not provided, the file will be saved in the current working directory.
unzip: bool, optional: Whether to unzip the downloaded file. Default: True.
delete: bool, optional: Whether to delete the downloaded file after unzipping. Default: False.
curl: bool, optional: Whether to use curl command line tool for downloading. Default: False.
retry: int, optional: The number of times to retry the download in case of failure. Default: 3.
min_bytes: float, optional: The minimum number of bytes that the downloaded file should have, to be considered
unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
curl (bool, optional): Whether to use curl command line tool for downloading. Default: False.
retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
a successful download. Default: 1E0.
progress: bool, optional: Whether to display a progress bar during the download. Default: True.
progress (bool, optional): Whether to display a progress bar during the download. Default: True.
"""
if '://' not in str(url) and Path(url).is_file(): # exists ('://' check required in Windows Python<3.10)
f = Path(url) # filename

View File

@ -30,13 +30,13 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
directory if it does not already exist.
Args:
path (str or pathlib.Path): Path to increment.
exist_ok (bool, optional): If True, the path will not be incremented and will be returned as-is. Defaults to False.
sep (str, optional): Separator to use between the path and the incrementation number. Defaults to an empty string.
mkdir (bool, optional): If True, the path will be created as a directory if it does not exist. Defaults to False.
path (str, pathlib.Path): Path to increment.
exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False.
sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''.
mkdir (bool, optional): Create a directory if it does not exist. Defaults to False.
Returns:
pathlib.Path: Incremented path.
(pathlib.Path): Incremented path.
"""
path = Path(path) # os-agnostic
if path.exists() and not exist_ok:

View File

@ -98,7 +98,7 @@ class Bboxes:
def mul(self, scale):
"""
Args:
scale (tuple | List | int): the scale for four coords.
scale (tuple) or (list) or (int): the scale for four coords.
"""
if isinstance(scale, Number):
scale = to_4tuple(scale)
@ -112,7 +112,7 @@ class Bboxes:
def add(self, offset):
"""
Args:
offset (tuple | List | int): the offset for four coords.
offset (tuple) or (list) or (int): the offset for four coords.
"""
if isinstance(offset, Number):
offset = to_4tuple(offset)
@ -129,13 +129,18 @@ class Bboxes:
@classmethod
def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
"""
Concatenates a list of Boxes into a single Bboxes
Concatenate a list of Bboxes objects into a single Bboxes object.
Arguments:
boxes_list (list[Bboxes])
Args:
boxes_list (List[Bboxes]): A list of Bboxes objects to concatenate.
axis (int, optional): The axis along which to concatenate the bounding boxes.
Defaults to 0.
Returns:
Bboxes: the concatenated Boxes
Bboxes: A new Bboxes object containing the concatenated bounding boxes.
Note:
The input should be a list or tuple of Bboxes objects.
"""
assert isinstance(boxes_list, (list, tuple))
if not boxes_list:
@ -148,11 +153,21 @@ class Bboxes:
def __getitem__(self, index) -> 'Bboxes':
"""
Retrieve a specific bounding box or a set of bounding boxes using indexing.
Args:
index: int, slice, or a BoolArray
index (int, slice, or np.ndarray): The index, slice, or boolean array to select
the desired bounding boxes.
Returns:
Bboxes: Create a new :class:`Bboxes` by indexing.
Bboxes: A new Bboxes object containing the selected bounding boxes.
Raises:
AssertionError: If the indexed bounding boxes do not form a 2-dimensional matrix.
Note:
When using boolean indexing, make sure to provide a boolean array with the same
length as the number of bounding boxes.
"""
if isinstance(index, int):
return Bboxes(self.bboxes[index].view(1, -1))
@ -236,11 +251,19 @@ class Instances:
def __getitem__(self, index) -> 'Instances':
"""
Retrieve a specific instance or a set of instances using indexing.
Args:
index: int, slice, or a BoolArray
index (int, slice, or np.ndarray): The index, slice, or boolean array to select
the desired instances.
Returns:
Instances: Create a new :class:`Instances` by indexing.
Instances: A new Instances object containing the selected bounding boxes,
segments, and keypoints if present.
Note:
When using boolean indexing, make sure to provide a boolean array with the same
length as the number of instances.
"""
segments = self.segments[index] if len(self.segments) else self.segments
keypoints = self.keypoints[index] if self.keypoints is not None else None
@ -305,14 +328,20 @@ class Instances:
@classmethod
def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
"""
Concatenates a list of Boxes into a single Bboxes
Concatenates a list of Instances objects into a single Instances object.
Arguments:
instances_list (list[Bboxes])
axis
Args:
instances_list (List[Instances]): A list of Instances objects to concatenate.
axis (int, optional): The axis along which the arrays will be concatenated. Defaults to 0.
Returns:
Boxes: the concatenated Boxes
Instances: A new Instances object containing the concatenated bounding boxes,
segments, and keypoints if present.
Note:
The `Instances` objects in the list should have the same properties, such as
the format of the bounding boxes, whether keypoints are present, and if the
coordinates are normalized.
"""
assert isinstance(instances_list, (list, tuple))
if not instances_list:

View File

@ -23,10 +23,16 @@ def box_area(box):
def bbox_ioa(box1, box2, eps=1e-7):
"""Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
box1: np.array of shape(nx4)
box2: np.array of shape(mx4)
returns: np.array of shape(nxm)
"""
Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
Args:
box1 (np.array): A numpy array of shape (n, 4) representing n bounding boxes.
box2 (np.array): A numpy array of shape (m, 4) representing m bounding boxes.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(np.array): A numpy array of shape (n, m) representing the intersection over box2 area.
"""
# Get the coordinates of bounding boxes
@ -46,17 +52,17 @@ def bbox_ioa(box1, box2, eps=1e-7):
def box_iou(box1, box2, eps=1e-7):
"""
Return intersection-over-union (Jaccard index) of boxes.
Calculate intersection-over-union (IoU) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
Arguments:
box1 (Tensor[N, 4])
box2 (Tensor[M, 4])
eps
Args:
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
(torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
"""
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
@ -68,7 +74,22 @@ def box_iou(box1, box2, eps=1e-7):
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
# Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
"""
Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).
Args:
box1 (torch.Tensor): A tensor representing a single bounding box with shape (1, 4).
box2 (torch.Tensor): A tensor representing n bounding boxes with shape (n, 4).
xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
(x1, y1, x2, y2) format. Defaults to True.
GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
"""
# Get the coordinates of bounding boxes
if xywh: # transform from xywh to xyxy
@ -110,10 +131,17 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
def mask_iou(mask1, mask2, eps=1e-7):
"""
mask1: [N, n] m1 means number of gt objects
mask2: [M, n] m2 means number of predicted objects
Note: n means image_w x image_h
Returns: masks iou, [N, M]
Calculate masks IoU.
Args:
mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
product of image width and height.
mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
product of image width and height.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): A tensor of shape (N, M) representing masks IoU.
"""
intersection = torch.matmul(mask1, mask2.t()).clamp(0)
union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection
@ -121,10 +149,18 @@ def mask_iou(mask1, mask2, eps=1e-7):
def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
"""OKS
kpt1: [N, 17, 3], gt
kpt2: [M, 17, 3], pred
area: [N], areas from gt
"""
Calculate Object Keypoint Similarity (OKS).
Args:
kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.
area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.
sigma (list): A list containing 17 values representing keypoint scales.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.
"""
d = (kpt1[:, None, :, 0] - kpt2[..., 0]) ** 2 + (kpt1[:, None, :, 1] - kpt2[..., 1]) ** 2 # (N, M, 17)
sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, )
@ -171,7 +207,17 @@ class FocalLoss(nn.Module):
class ConfusionMatrix:
# Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
"""
A class for calculating and updating a confusion matrix for object detection and classification tasks.
Attributes:
task (str): The type of task, either 'detect' or 'classify'.
matrix (np.array): The confusion matrix, with dimensions depending on the task.
nc (int): The number of classes.
conf (float): The confidence threshold for detections.
iou_thres (float): The Intersection over Union threshold.
"""
def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'):
self.task = task
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc))
@ -183,12 +229,9 @@ class ConfusionMatrix:
"""
Update confusion matrix for classification task
Arguments:
preds (Array[N, min(nc,5)])
targets (Array[N, 1])
Returns:
None, updates confusion matrix accordingly
Args:
preds (Array[N, min(nc,5)]): Predicted class labels.
targets (Array[N, 1]): Ground truth class labels.
"""
preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
@ -196,15 +239,13 @@ class ConfusionMatrix:
def process_batch(self, detections, labels):
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Update confusion matrix for object detection task.
Arguments:
detections (Array[N, 6]), x1, y1, x2, y2, conf, class
labels (Array[M, 5]), class, x1, y1, x2, y2
Returns:
None, updates confusion matrix accordingly
Args:
detections (Array[N, 6]): Detected bounding boxes and their associated information.
Each row should contain (x1, y1, x2, y2, conf, class).
labels (Array[M, 5]): Ground truth bounding boxes and their associated class labels.
Each row should contain (class, x1, y1, x2, y2).
"""
if detections is None:
gt_classes = labels.int()
@ -254,6 +295,14 @@ class ConfusionMatrix:
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
@plt_settings()
def plot(self, normalize=True, save_dir='', names=()):
"""
Plot the confusion matrix using seaborn and save it to a file.
Args:
normalize (bool): Whether to normalize the confusion matrix.
save_dir (str): Directory where the plot will be saved.
names (tuple): Names of classes, used as labels on the plot.
"""
import seaborn as sn
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
@ -284,6 +333,9 @@ class ConfusionMatrix:
plt.close(fig)
def print(self):
"""
Print the confusion matrix to the console.
"""
for i in range(self.nc + 1):
LOGGER.info(' '.join(map(str, self.matrix[i])))
@ -343,12 +395,17 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
def compute_ap(recall, precision):
""" Compute the average precision, given the recall and precision curves
"""
Compute the average precision (AP) given the recall and precision curves.
Arguments:
recall: The recall curve (list)
precision: The precision curve (list)
recall (list): The recall curve.
precision (list): The precision curve.
Returns:
Average precision, precision curve, recall curve
(float): Average precision.
(np.ndarray): Precision envelope curve.
(np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
"""
# Append sentinel values to beginning and end
@ -488,57 +545,71 @@ class Metric(SimpleClass):
@property
def ap50(self):
"""AP@0.5 of all classes.
"""
Returns the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
Returns:
(nc, ) or [].
(np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
"""
return self.all_ap[:, 0] if len(self.all_ap) else []
@property
def ap(self):
"""AP@0.5:0.95
"""
Returns the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
Returns:
(nc, ) or [].
(np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
"""
return self.all_ap.mean(1) if len(self.all_ap) else []
@property
def mp(self):
"""mean precision of all classes.
"""
Returns the Mean Precision of all classes.
Returns:
float.
(float): The mean precision of all classes.
"""
return self.p.mean() if len(self.p) else 0.0
@property
def mr(self):
"""mean recall of all classes.
"""
Returns the Mean Recall of all classes.
Returns:
float.
(float): The mean recall of all classes.
"""
return self.r.mean() if len(self.r) else 0.0
@property
def map50(self):
"""Mean AP@0.5 of all classes.
"""
Returns the mean Average Precision (mAP) at an IoU threshold of 0.5.
Returns:
float.
(float): The mAP50 at an IoU threshold of 0.5.
"""
return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
@property
def map75(self):
"""Mean AP@0.75 of all classes.
"""
Returns the mean Average Precision (mAP) at an IoU threshold of 0.75.
Returns:
float.
(float): The mAP50 at an IoU threshold of 0.75.
"""
return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
@property
def map(self):
"""Mean AP@0.5:0.95 of all classes.
"""
Returns the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
Returns:
float.
(float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
"""
return self.all_ap.mean() if len(self.all_ap) else 0.0
@ -566,7 +637,7 @@ class Metric(SimpleClass):
def update(self, results):
"""
Args:
results: tuple(p, r, ap, f1, ap_class)
results (tuple): A tuple of (p, r, ap, f1, ap_class)
"""
self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results

View File

@ -120,10 +120,10 @@ def make_divisible(x, divisor):
Args:
x (int): The number to make divisible.
divisor (int or torch.Tensor): The divisor.
divisor (int) or (torch.Tensor): The divisor.
Returns:
int: The nearest number divisible by the divisor.
(int): The nearest number divisible by the divisor.
"""
if isinstance(divisor, torch.Tensor):
divisor = int(divisor.max()) # to int

View File

@ -127,7 +127,7 @@ class Annotator:
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
mcs = (masks_color * inv_alph_masks).sum(0) * 2 # mask color summand shape(n,h,w,3)
mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
im_gpu = im_gpu.flip(dims=[0]) # flip channel
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
@ -140,12 +140,16 @@ class Annotator:
self.fromarray(self.im)
def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True):
"""Plot keypoints.
"""Plot keypoints on the image.
Args:
kpts (tensor): predicted kpts, shape: [17, 3]
shape (tuple): image shape, (h, w)
steps (int): keypoints step
radius (int): size of drawing points
kpts (tensor): Predicted keypoints with shape [17, 3]. Each keypoint has (x, y, confidence).
shape (tuple): Image shape as a tuple (h, w), where h is the height and w is the width.
radius (int, optional): Radius of the drawn keypoints. Default is 5.
kpt_line (bool, optional): If True, the function will draw lines connecting keypoints
for human pose. Default is True.
Note: `kpt_line=True` currently only supports human pose plotting.
"""
if self.pil:
# convert to numpy first

View File

@ -54,6 +54,19 @@ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
class TaskAlignedAssigner(nn.Module):
"""
A task-aligned assigner for object detection.
This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric,
which combines both classification and localization information.
Attributes:
topk (int): The number of top candidates to consider.
num_classes (int): The number of object classes.
alpha (float): The alpha parameter for the classification component of the task-aligned metric.
beta (float): The beta parameter for the localization component of the task-aligned metric.
eps (float): A small value to prevent division by zero.
"""
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
super().__init__()
@ -66,8 +79,9 @@ class TaskAlignedAssigner(nn.Module):
@torch.no_grad()
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
"""This code referenced to
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
"""
Compute the task-aligned assignment.
Reference https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
Args:
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
@ -76,11 +90,13 @@ class TaskAlignedAssigner(nn.Module):
gt_labels (Tensor): shape(bs, n_max_boxes, 1)
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
mask_gt (Tensor): shape(bs, n_max_boxes, 1)
Returns:
target_labels (Tensor): shape(bs, num_total_anchors)
target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
fg_mask (Tensor): shape(bs, num_total_anchors)
target_gt_idx (Tensor): shape(bs, num_total_anchors)
"""
self.bs = pd_scores.size(0)
self.n_max_boxes = gt_bboxes.size(1)
@ -142,9 +158,19 @@ class TaskAlignedAssigner(nn.Module):
def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
"""
Select the top-k candidates based on the given metrics.
Args:
metrics: (b, max_num_obj, h*w).
topk_mask: (b, max_num_obj, topk) or None
metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
max_num_obj is the maximum number of objects, and h*w represents the
total number of anchor points.
largest (bool): If True, select the largest values; otherwise, select the smallest values.
topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
topk is the number of top candidates to consider. If not provided,
the top-k values are automatically computed based on the given metrics.
Returns:
(Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
"""
num_anchors = metrics.shape[-1] # h*w
@ -165,22 +191,38 @@ class TaskAlignedAssigner(nn.Module):
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
"""
Compute target labels, target bounding boxes, and target scores for the positive anchor points.
Args:
gt_labels: (b, max_num_obj, 1)
gt_bboxes: (b, max_num_obj, 4)
target_gt_idx: (b, h*w)
fg_mask: (b, h*w)
gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
batch size and max_num_obj is the maximum number of objects.
gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
anchor points, with shape (b, h*w), where h*w is the total
number of anchor points.
fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
(foreground) anchor points.
Returns:
(Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
- target_labels (Tensor): Shape (b, h*w), containing the target labels for
positive anchor points.
- target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes
for positive anchor points.
- target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
for positive anchor points, where num_classes is the number
of object classes.
"""
# assigned target labels, (b, 1)
# Assigned target labels, (b, 1)
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)
target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w)
# assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
# Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
# assigned target scores
# Assigned target scores
target_labels.clamp(0)
target_scores = F.one_hot(target_labels, self.num_classes) # (b, h*w, 80)
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)

View File

@ -427,7 +427,7 @@ class EarlyStopping:
fitness (float): Fitness value of current epoch
Returns:
bool: True if training should stop, False otherwise
(bool): True if training should stop, False otherwise
"""
if fitness is None: # check if fitness=None (happens when val=False)
return False