ultralytics 8.0.79 expand Docs reference section (#2053)

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Fri3dChicken <87434761+AmoghDhaliwal@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-04-16 12:28:12 +02:00
committed by GitHub
parent 47bd8b433b
commit 31db8ed163
106 changed files with 2570 additions and 529 deletions

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.78'
__version__ = '8.0.79'
from ultralytics.hub import start
from ultralytics.yolo.engine.model import YOLO

View File

@ -111,7 +111,7 @@ class Auth:
Get the authentication header for making API requests.
Returns:
dict: The authentication header if id_token or API key is set, None otherwise.
(dict): The authentication header if id_token or API key is set, None otherwise.
"""
if self.id_token:
return {'authorization': f'Bearer {self.id_token}'}

View File

@ -30,7 +30,7 @@ def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', s
sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 2.0.
Returns:
bool: True if there is sufficient disk space, False otherwise.
(bool): True if there is sufficient disk space, False otherwise.
"""
gib = 1 << 30 # bytes per GiB
data = int(requests.head(url).headers['Content-Length']) / gib  # dataset size (GiB)
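
The full check this hunk documents reduces to a few lines; a minimal sketch, assuming the server reports Content-Length on a HEAD request (illustrative, not the library's exact body):

```python
import shutil
import requests

def has_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0):
    gib = 1 << 30  # bytes per GiB
    data = int(requests.head(url).headers['Content-Length']) / gib  # dataset size (GiB)
    total, used, free = (x / gib for x in shutil.disk_usage('/'))  # disk usage (GiB)
    return data * sf < free  # True if the dataset fits with the safety factor applied
```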
@ -51,7 +51,7 @@ def request_with_credentials(url: str) -> any:
url (str): The URL to make the request to.
Returns:
any: The response data from the AJAX request.
(any): The response data from the AJAX request.
Raises:
OSError: If the function is not run in a Google Colab environment.
@ -87,11 +87,14 @@ def requests_with_progress(method, url, **kwargs):
Args:
method (str): The HTTP method to use (e.g. 'GET', 'POST').
url (str): The URL to send the request to.
progress (bool, optional): Whether to display a progress bar. Defaults to False.
**kwargs: Additional keyword arguments to pass to the underlying `requests.request` function.
**kwargs (dict): Additional keyword arguments to pass to the underlying `requests.request` function.
Returns:
requests.Response: The response from the HTTP request.
(requests.Response): The response object from the HTTP request.
Note:
If 'progress' is set to True, the progress bar will display the download progress
for responses with a known content length.
"""
progress = kwargs.pop('progress', False)
if not progress:
@ -118,10 +121,10 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
progress (bool, optional): Whether to show a progress bar during the request. Default is False.
**kwargs: Keyword arguments to be passed to the requests function specified in method.
**kwargs (dict): Keyword arguments to be passed to the requests function specified in method.
Returns:
requests.Response: The HTTP response object. If the request is executed in a separate thread, returns None.
(requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
"""
retry_codes = (408, 500) # retry only these codes
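
A rough sketch of the progress-bar behavior `requests_with_progress` documents above, assuming a tqdm-style bar keyed off Content-Length (not the library's exact implementation):

```python
import requests
from tqdm import tqdm

def request_with_progress(method, url, **kwargs):
    response = requests.request(method, url, stream=True, **kwargs)
    total = int(response.headers.get('Content-Length', 0))  # 0 if length is unknown
    with tqdm(total=total, unit='B', unit_scale=True) as pbar:
        for chunk in response.iter_content(chunk_size=1024):  # stream and track bytes
            pbar.update(len(chunk))
    return response
```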

View File

@ -337,7 +337,7 @@ def torch_safe_load(weight):
weight (str): The file path of the PyTorch model.
Returns:
The loaded PyTorch model.
(dict): The loaded PyTorch model.
"""
from ultralytics.yolo.utils.downloads import attempt_download_asset
@ -398,7 +398,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
for k in 'names', 'nc', 'yaml':
setattr(ensemble, k, getattr(ensemble[0], k))
ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts: {[m.nc for m in ensemble]}'
assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts {[m.nc for m in ensemble]}'
return ensemble
@ -520,7 +520,7 @@ def guess_model_scale(model_path):
which is denoted by n, s, m, l, or x. The function returns the size character of the model scale as a string.
Args:
model_path (str or Path): The path to the YOLO model's YAML file.
model_path (str) or (Path): The path to the YOLO model's YAML file.
Returns:
(str): The size character of the model's scale, which can be n, s, m, l, or x.
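
The scale lookup amounts to a regular-expression match on the filename stem; a minimal sketch (the `yolov\d+` pattern is an assumption about the naming scheme):

```python
import re
from pathlib import Path

def guess_scale(model_path):
    m = re.search(r'yolov\d+([nslmx])', Path(model_path).stem)  # e.g. 'yolov8n' -> 'n'
    return m.group(1) if m else ''
```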
@ -539,7 +539,7 @@ def guess_model_task(model):
model (nn.Module) or (dict): PyTorch model or model configuration in YAML format.
Returns:
str: Task of the model ('detect', 'segment', 'classify', 'pose').
(str): Task of the model ('detect', 'segment', 'classify', 'pose').
Raises:
SyntaxError: If the task of the model could not be determined.

View File

@ -190,11 +190,24 @@ def fuse_score(cost_matrix, detections):
def bbox_ious(box1, box2, eps=1e-7):
"""Boxes are x1y1x2y2
box1: np.array of shape(nx4)
box2: np.array of shape(mx4)
returns: np.array of shape(nxm)
"""
Calculate the Intersection over Union (IoU) between pairs of bounding boxes.
Args:
box1 (np.array): A numpy array of shape (n, 4) representing 'n' bounding boxes.
Each row is in the format (x1, y1, x2, y2).
box2 (np.array): A numpy array of shape (m, 4) representing 'm' bounding boxes.
Each row is in the format (x1, y1, x2, y2).
eps (float, optional): A small constant to prevent division by zero. Defaults to 1e-7.
Returns:
(np.array): A numpy array of shape (n, m) representing the IoU scores for each pair
of bounding boxes from box1 and box2.
Note:
The bounding box coordinates are expected to be in the format (x1, y1, x2, y2).
"""
# Get the coordinates of bounding boxes
b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
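
A hedged completion of the pairwise computation started above, broadcasting the (n,) coordinates against the (m,) coordinates to produce the (n, m) IoU matrix (a sketch consistent with the docstring, not necessarily the exact body):

```python
import numpy as np

def pairwise_iou(box1, box2, eps=1e-7):
    b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
    # Intersection area, broadcast to shape (n, m)
    inter_w = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0)
    inter_h = (np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)
    inter = inter_w * inter_h
    area1 = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)  # (n,)
    area2 = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)  # (m,)
    return inter / (area1[:, None] + area2 - inter + eps)  # (n, m)
```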

View File

@ -262,13 +262,15 @@ class RandomPerspective:
return img, M, s
def apply_bboxes(self, bboxes, M):
"""apply affine to bboxes only.
"""
Apply affine to bboxes only.
Args:
bboxes(ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
M(ndarray): affine matrix.
bboxes (ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
M (ndarray): affine matrix.
Returns:
new_bboxes(ndarray): bboxes after affine, [num_bboxes, 4].
new_bboxes (ndarray): bboxes after affine, [num_bboxes, 4].
"""
n = len(bboxes)
if n == 0:
@ -285,14 +287,16 @@ class RandomPerspective:
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
def apply_segments(self, segments, M):
"""apply affine to segments and generate new bboxes from segments.
"""
Apply affine to segments and generate new bboxes from segments.
Args:
segments(ndarray): list of segments, [num_samples, 500, 2].
M(ndarray): affine matrix.
segments (ndarray): list of segments, [num_samples, 500, 2].
M (ndarray): affine matrix.
Returns:
new_segments(ndarray): list of segments after affine, [num_samples, 500, 2].
new_bboxes(ndarray): bboxes after affine, [N, 4].
new_segments (ndarray): list of segments after affine, [num_samples, 500, 2].
new_bboxes (ndarray): bboxes after affine, [N, 4].
"""
n, num = segments.shape[:2]
if n == 0:
@ -308,13 +312,15 @@ class RandomPerspective:
return bboxes, segments
def apply_keypoints(self, keypoints, M):
"""apply affine to keypoints.
"""
Apply affine to keypoints.
Args:
keypoints(ndarray): keypoints, [N, 17, 3].
M(ndarray): affine matrix.
keypoints (ndarray): keypoints, [N, 17, 3].
M (ndarray): affine matrix.
Return:
new_keypoints(ndarray): keypoints after affine, [N, 17, 3].
new_keypoints (ndarray): keypoints after affine, [N, 17, 3].
"""
n, nkpt = keypoints.shape[:2]
if n == 0:
@ -333,7 +339,7 @@ class RandomPerspective:
Affine images and targets.
Args:
labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
labels (dict): a dict of `bboxes`, `segments`, `keypoints`.
"""
if self.pre_transform and 'mosaic_border' not in labels:
labels = self.pre_transform(labels)
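
For reference, the corner-transform approach that `apply_bboxes` documents above reduces to: lift each box to its four corners, apply the 3x3 matrix in homogeneous coordinates, then re-box with min/max. A sketch under those assumptions:

```python
import numpy as np

def affine_bboxes(bboxes, M):
    n = len(bboxes)
    xy = np.ones((n * 4, 3))
    # Four corners per box: (x1,y1), (x2,y2), (x1,y2), (x2,y1)
    xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)
    xy = xy @ M.T  # apply affine/perspective transform
    xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # perspective divide
    x, y = xy[:, [0, 2, 4, 6]], xy[:, [1, 3, 5, 7]]
    return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
```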

View File

@ -18,11 +18,30 @@ from .utils import HELP_URL, IMG_FORMATS
class BaseDataset(Dataset):
"""Base Dataset.
"""
Base dataset class for loading and processing image data.
Args:
img_path (str): image path.
pipeline (dict): a dict of image transforms.
label_path (str): label path, this can also be an ann_file or other custom label path.
img_path (str): Image path.
imgsz (int): Target image size for resizing. Default is 640.
cache (bool): Cache images in memory or on disk for faster loading. Default is False.
augment (bool): Apply data augmentation. Default is True.
hyp (dict): Dictionary of hyperparameters for data augmentation. Default is None.
prefix (str): Prefix for file paths. Default is an empty string.
rect (bool): Enable rectangular training. Default is False.
batch_size (int): Batch size for rectangular training. Default is None.
stride (int): Stride for rectangular training. Default is 32.
pad (float): Padding for rectangular training. Default is 0.5.
single_cls (bool): Use a single class for all labels. Default is False.
classes (list): List of included classes. Default is None.
Attributes:
im_files (list): List of image file paths.
labels (list): List of label data dictionaries.
ni (int): Number of images in the dataset.
ims (list): List of loaded images.
npy_files (list): List of numpy file paths.
transforms (callable): Image transformation function.
"""
def __init__(self,

View File

@ -21,10 +21,7 @@ from .utils import PIN_MEMORY
class InfiniteDataLoader(dataloader.DataLoader):
"""Dataloader that reuses workers
Uses same syntax as vanilla DataLoader
"""
"""Dataloader that reuses workers. Uses same syntax as vanilla DataLoader."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -40,10 +37,11 @@ class InfiniteDataLoader(dataloader.DataLoader):
class _RepeatSampler:
"""Sampler that repeats forever
"""
Sampler that repeats forever.
Args:
sampler (Sampler)
sampler (Dataset.sampler): The sampler to repeat.
"""
def __init__(self, sampler):
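
The repeat-forever behavior is essentially one loop; a minimal sketch of the idea:

```python
class RepeatSampler:
    """Yields from the wrapped sampler indefinitely (sketch of the docstring above)."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:  # restart the underlying sampler once it is exhausted
            yield from iter(self.sampler)
```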
@ -173,7 +171,7 @@ def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1,
auto (bool, optional): Automatically apply pre-processing. Default is True.
Returns:
dataset: A dataset object for the specified input source.
dataset (Dataset): A dataset object for the specified input source.
"""
source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)

View File

@ -178,7 +178,7 @@ class _RepeatSampler:
""" Sampler that repeats forever
Args:
sampler (Sampler)
sampler (Dataset.sampler): The sampler to repeat.
"""
def __init__(self, sampler):

View File

@ -17,31 +17,31 @@ from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_lab
class YOLODataset(BaseDataset):
cache_version = '1.0.2' # dataset labels *.cache version, >= 1.0.0 for YOLOv8
rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
"""
Dataset class for loading images object detection and/or segmentation labels in YOLO format.
Dataset class for loading object detection and/or segmentation labels in YOLO format.
Args:
img_path (str): path to the folder containing images.
imgsz (int): image size (default: 640).
cache (bool): if True, a cache file of the labels is created to speed up future creation of dataset instances
(default: False).
augment (bool): if True, data augmentation is applied (default: True).
hyp (dict): hyperparameters to apply data augmentation (default: None).
prefix (str): prefix to print in log messages (default: '').
rect (bool): if True, rectangular training is used (default: False).
batch_size (int): size of batches (default: None).
stride (int): stride (default: 32).
pad (float): padding (default: 0.0).
single_cls (bool): if True, single class training is used (default: False).
use_segments (bool): if True, segmentation masks are used as labels (default: False).
use_keypoints (bool): if True, keypoints are used as labels (default: False).
names (dict): A dictionary of class names. (default: None).
img_path (str): Path to the folder containing images.
imgsz (int, optional): Image size. Defaults to 640.
cache (bool, optional): Cache images to RAM or disk during training. Defaults to False.
augment (bool, optional): If True, data augmentation is applied. Defaults to True.
hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None.
prefix (str, optional): Prefix to print in log messages. Defaults to ''.
rect (bool, optional): If True, rectangular training is used. Defaults to False.
batch_size (int, optional): Size of batches. Defaults to None.
stride (int, optional): Stride. Defaults to 32.
pad (float, optional): Padding. Defaults to 0.0.
single_cls (bool, optional): If True, single class training is used. Defaults to False.
use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.
data (dict, optional): A dataset YAML dictionary. Defaults to None.
classes (list): List of included classes. Default is None.
Returns:
A PyTorch dataset object that can be used for training an object detection or segmentation model.
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
cache_version = '1.0.2' # dataset labels *.cache version, >= 1.0.0 for YOLOv8
rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
def __init__(self,
img_path,

View File

@ -7,21 +7,36 @@ from .augment import LetterBox
class MixAndRectDataset:
"""A wrapper of multiple images mixed dataset.
"""
A dataset class that applies mosaic and mixup transformations as well as rectangular training.
Args:
dataset (:obj:`BaseDataset`): The dataset to be mixed.
transforms (Sequence[dict]): config dict to be composed.
Attributes:
dataset: The base dataset.
imgsz: The size of the images in the dataset.
"""
def __init__(self, dataset):
"""
Args:
dataset (BaseDataset): The base dataset to apply transformations to.
"""
self.dataset = dataset
self.imgsz = dataset.imgsz
def __len__(self):
"""Returns the number of items in the dataset."""
return len(self.dataset)
def __getitem__(self, index):
"""
Applies mosaic, mixup and rectangular training transformations to an item in the dataset.
Args:
index (int): Index of the item in the dataset.
Returns:
(dict): A dictionary containing the transformed item data.
"""
labels = deepcopy(self.dataset[index])
for transform in self.dataset.transforms.tolist():
# mosaic and mixup

View File

@ -270,9 +270,8 @@ def check_cls_dataset(dataset: str):
"""
Check a classification dataset such as Imagenet.
Copy code
This function takes a `dataset` name as input and returns a dictionary containing information about the dataset.
If the dataset is not found, it attempts to download the dataset from the internet and save it to the local file system.
If the dataset is not found, it attempts to download the dataset from the internet and save it locally.
Args:
dataset (str): Name of the dataset.
@ -306,7 +305,8 @@ def check_cls_dataset(dataset: str):
class HUBDatasetStats():
""" Class for generating HUB dataset JSON and `-hub` dataset directory
"""
Class for generating HUB dataset JSON and `-hub` dataset directory
Args:
path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip).
@ -427,9 +427,6 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
quality (int, optional): The image compression quality as a percentage. Default is 50%.
Returns:
None
Usage:
from pathlib import Path
from ultralytics.yolo.data.utils import compress_one_image
@ -459,9 +456,6 @@ def delete_dsstore(path):
Args:
path (str, optional): The directory path where the ".DS_store" files should be deleted.
Returns:
None
Usage:
from ultralytics.yolo.data.utils import delete_dsstore
delete_dsstore('/Users/glennjocher/Downloads/dataset')
@ -478,15 +472,13 @@ def delete_dsstore(path):
def zip_directory(dir, use_zipfile_library=True):
"""Zips a directory and saves the archive to the specified output path.
"""
Zips a directory and saves the archive to the specified output path.
Args:
dir (str): The path to the directory to be zipped.
use_zipfile_library (bool): Whether to use zipfile library or shutil for zipping.
Returns:
None
Usage:
from ultralytics.yolo.data.utils import zip_directory
zip_directory('/Users/glennjocher/Downloads/playground')

View File

@ -73,7 +73,7 @@ ARM64 = platform.machine() in ('arm64', 'aarch64')
def export_formats():
# YOLOv8 export formats
"""YOLOv8 export formats"""
import pandas
x = [
['PyTorch', '-', '.pt', True, True],
@ -92,7 +92,7 @@ def export_formats():
def gd_outputs(gd):
# TensorFlow GraphDef model output node names
"""TensorFlow GraphDef model output node names"""
name_list, input_list = [], []
for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
name_list.append(node.name)
@ -101,7 +101,7 @@ def gd_outputs(gd):
def try_export(inner_func):
# YOLOv8 export decorator, i.e. @try_export
"""YOLOv8 export decorator, i.e. @try_export"""
inner_args = get_default_args(inner_func)
def outer_func(*args, **kwargs):
@ -118,10 +118,26 @@ def try_export(inner_func):
return outer_func
class iOSDetectModel(torch.nn.Module):
"""Wrap an Ultralytics YOLO model for iOS export"""
def __init__(self, model, im):
super().__init__()
b, c, h, w = im.shape # batch, channel, height, width
self.model = model
self.nc = len(model.names) # number of classes
if w == h:
self.normalize = 1.0 / w # scalar
else:
self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)
def forward(self, x):
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
class Exporter:
"""
Exporter
A class for exporting a model.
Attributes:
@ -136,6 +152,7 @@ class Exporter:
Args:
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
overrides (dict, optional): Configuration overrides. Defaults to None.
_callbacks (list, optional): List of callback functions. Defaults to None.
"""
self.args = get_cfg(cfg, overrides)
self.callbacks = _callbacks or callbacks.get_default_callbacks()
@ -385,22 +402,6 @@ class Exporter:
check_requirements('coremltools>=6.0')
import coremltools as ct # noqa
class iOSDetectModel(torch.nn.Module):
# Wrap an Ultralytics YOLO model for iOS export
def __init__(self, model, im):
super().__init__()
b, c, h, w = im.shape # batch, channel, height, width
self.model = model
self.nc = len(model.names) # number of classes
if w == h:
self.normalize = 1.0 / w # scalar
else:
self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)
def forward(self, x):
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
f = self.file.with_suffix('.mlmodel')

View File

@ -400,7 +400,7 @@ class YOLO:
train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.
Returns:
A dictionary containing the results of the hyperparameter search.
(dict): A dictionary containing the results of the hyperparameter search.
Raises:
ModuleNotFoundError: If Ray Tune is not installed.

View File

@ -127,7 +127,10 @@ class BasePredictor:
log_string += result.verbose()
if self.args.save or self.args.show: # Add bbox to image
plot_args = dict(line_width=self.args.line_thickness, boxes=self.args.boxes)
plot_args = dict(line_width=self.args.line_thickness,
boxes=self.args.boxes,
conf=self.args.show_conf,
labels=self.args.show_labels)
if not self.args.retina_masks:
plot_args['im_gpu'] = im[idx]
self.plotted_img = result.plot(**plot_args)

View File

@ -621,7 +621,7 @@ def check_amp(model):
model (nn.Module): A YOLOv8 model instance.
Returns:
bool: Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
Raises:
AssertionError: If the AMP checks fail, indicating anomalies with the AMP functionality on the system.

View File

@ -22,7 +22,7 @@ def check_train_batch_size(model, imgsz=640, amp=True):
amp (bool): If True, use automatic mixed precision (AMP) for training.
Returns:
int: Optimal batch size computed using the autobatch() function.
(int): Optimal batch size computed using the autobatch() function.
"""
with torch.cuda.amp.autocast(amp):
@ -34,13 +34,13 @@ def autobatch(model, imgsz=640, fraction=0.67, batch_size=16):
Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
Args:
model: YOLO model to compute batch size for.
model (torch.nn.Module): YOLO model to compute batch size for.
imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67.
batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.
Returns:
int: The optimal batch size.
(int): The optimal batch size.
"""
# Check device
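
The estimation strategy amounts to profiling memory at a few batch sizes and solving a first-degree fit for the target fraction; a simplified sketch (the batch sizes, measured usage, and free memory are assumed inputs in GiB):

```python
import numpy as np

def solve_batch_size(batch_sizes, mem_used, free_mem, fraction=0.67):
    p = np.polyfit(batch_sizes, mem_used, deg=1)  # linear fit: mem ~= p[0] * b + p[1]
    return int((free_mem * fraction - p[1]) / p[0])  # batch size hitting the target fraction
```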

View File

@ -19,14 +19,14 @@ except (ImportError, AssertionError):
clearml = None
def _log_debug_samples(files, title='Debug Samples'):
def _log_debug_samples(files, title='Debug Samples') -> None:
"""
Log files (images) as debug samples in the ClearML task.
Log files (images) as debug samples in the ClearML task.
arguments:
files (List(PosixPath)) a list of file paths in PosixPath format
title (str) A title that groups together images with the same values
"""
Args:
files (list): A list of file paths in PosixPath format.
title (str): A title that groups together images with the same values.
"""
task = Task.current_task()
if task:
for f in files:
@ -39,20 +39,23 @@ def _log_debug_samples(files, title='Debug Samples'):
iteration=iteration)
def _log_plot(title, plot_path):
def _log_plot(title, plot_path) -> None:
"""
Log image as plot in the plot section of ClearML
Log an image as a plot in the plot section of ClearML.
arguments:
title (str) Title of the plot
plot_path (PosixPath or str) Path to the saved image file
"""
Args:
title (str): The title of the plot.
plot_path (str): The path to the saved image file.
"""
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
ax.imshow(img)
Task.current_task().get_logger().report_matplotlib_figure(title, '', figure=fig, report_interactive=False)
Task.current_task().get_logger().report_matplotlib_figure(title=title,
series='',
figure=fig,
report_interactive=False)
def on_pretrain_routine_start(trainer):

View File

@ -47,13 +47,13 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
Args:
imgsz (int or List[int]): Image size.
imgsz (int) or (List[int]): Image size.
stride (int): Stride value.
min_dim (int): Minimum number of dimensions.
floor (int): Minimum allowed value for image size.
Returns:
List[int]: Updated image size.
(List[int]): Updated image size.
"""
# Convert stride to integer if it is a tensor
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
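
The core update rule is a stride-aligned ceiling with a floor; a minimal sketch:

```python
import math

def round_to_stride(size, stride=32, floor=0):
    return max(math.ceil(size / stride) * stride, floor)  # nearest multiple >= size, at least floor
```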
@ -106,7 +106,7 @@ def check_version(current: str = '0.0.0',
verbose (bool): If True, print warning message if minimum version is not met.
Returns:
bool: True if minimum version is met, False otherwise.
(bool): True if minimum version is met, False otherwise.
"""
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
result = (current == minimum) if pinned else (current >= minimum) # bool
@ -126,7 +126,7 @@ def check_latest_pypi_version(package_name='ultralytics'):
package_name (str): The name of the package to find the latest version for.
Returns:
str: The latest version of the package.
(str): The latest version of the package.
"""
requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
response = requests.get(f'https://pypi.org/pypi/{package_name}/json', verify=False)
@ -140,7 +140,7 @@ def check_pip_update_available():
Checks if a new version of the ultralytics package is available on PyPI.
Returns:
bool: True if an update is available, False otherwise.
(bool): True if an update is available, False otherwise.
"""
if ONLINE and is_pip_package():
with contextlib.suppress(Exception):
@ -206,9 +206,6 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
exclude (Tuple[str]): Tuple of package names to exclude from checking.
install (bool): If True, attempt to auto-update packages that don't meet requirements.
cmds (str): Additional commands to pass to the pip install command when auto-updating.
Returns:
None
"""
prefix = colorstr('red', 'bold', 'requirements:')
check_python() # check python version

View File

@ -67,21 +67,21 @@ def safe_download(url,
min_bytes=1E0,
progress=True):
"""
Function for downloading files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
Args:
url: str: The URL of the file to be downloaded.
file: str, optional: The filename of the downloaded file.
url (str): The URL of the file to be downloaded.
file (str, optional): The filename of the downloaded file.
If not provided, the file will be saved with the same name as the URL.
dir: str, optional: The directory to save the downloaded file.
dir (str, optional): The directory to save the downloaded file.
If not provided, the file will be saved in the current working directory.
unzip: bool, optional: Whether to unzip the downloaded file. Default: True.
delete: bool, optional: Whether to delete the downloaded file after unzipping. Default: False.
curl: bool, optional: Whether to use curl command line tool for downloading. Default: False.
retry: int, optional: The number of times to retry the download in case of failure. Default: 3.
min_bytes: float, optional: The minimum number of bytes that the downloaded file should have, to be considered
unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
curl (bool, optional): Whether to use curl command line tool for downloading. Default: False.
retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
a successful download. Default: 1E0.
progress: bool, optional: Whether to display a progress bar during the download. Default: True.
progress (bool, optional): Whether to display a progress bar during the download. Default: True.
"""
if '://' not in str(url) and Path(url).is_file(): # exists ('://' check required in Windows Python<3.10)
f = Path(url) # filename
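
Hypothetical usage of the updated signature (URL and paths are illustrative):

```python
from ultralytics.yolo.utils.downloads import safe_download

safe_download(url='https://ultralytics.com/assets/coco128.zip', dir='datasets', unzip=True, retry=3)
```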

View File

@ -30,13 +30,13 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
directory if it does not already exist.
Args:
path (str or pathlib.Path): Path to increment.
exist_ok (bool, optional): If True, the path will not be incremented and will be returned as-is. Defaults to False.
sep (str, optional): Separator to use between the path and the incrementation number. Defaults to an empty string.
mkdir (bool, optional): If True, the path will be created as a directory if it does not exist. Defaults to False.
path (str, pathlib.Path): Path to increment.
exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False.
sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''.
mkdir (bool, optional): Create a directory if it does not exist. Defaults to False.
Returns:
pathlib.Path: Incremented path.
(pathlib.Path): Incremented path.
"""
path = Path(path) # os-agnostic
if path.exists() and not exist_ok:
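
A minimal sketch of the increment loop the docstring describes (runs/exp -> runs/exp2, runs/exp3, ...):

```python
from pathlib import Path

def increment_path(path, exist_ok=False, sep=''):
    path = Path(path)
    if path.exists() and not exist_ok:
        suffix = path.suffix
        stem = path.with_suffix('')  # path without its extension
        for n in range(2, 9999):
            candidate = Path(f'{stem}{sep}{n}{suffix}')  # try the next numbered variant
            if not candidate.exists():
                return candidate
    return path
```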

View File

@ -98,7 +98,7 @@ class Bboxes:
def mul(self, scale):
"""
Args:
scale (tuple | List | int): the scale for four coords.
scale (tuple) or (list) or (int): the scale for four coords.
"""
if isinstance(scale, Number):
scale = to_4tuple(scale)
@ -112,7 +112,7 @@ class Bboxes:
def add(self, offset):
"""
Args:
offset (tuple | List | int): the offset for four coords.
offset (tuple) or (list) or (int): the offset for four coords.
"""
if isinstance(offset, Number):
offset = to_4tuple(offset)
@ -129,13 +129,18 @@ class Bboxes:
@classmethod
def concatenate(cls, boxes_list: List['Bboxes'], axis=0) -> 'Bboxes':
"""
Concatenates a list of Boxes into a single Bboxes
Concatenate a list of Bboxes objects into a single Bboxes object.
Arguments:
boxes_list (list[Bboxes])
Args:
boxes_list (List[Bboxes]): A list of Bboxes objects to concatenate.
axis (int, optional): The axis along which to concatenate the bounding boxes.
Defaults to 0.
Returns:
Bboxes: the concatenated Boxes
Bboxes: A new Bboxes object containing the concatenated bounding boxes.
Note:
The input should be a list or tuple of Bboxes objects.
"""
assert isinstance(boxes_list, (list, tuple))
if not boxes_list:
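
Hypothetical usage of `Bboxes.concatenate` (assuming the default 'xyxy' format):

```python
import numpy as np
from ultralytics.yolo.utils.instance import Bboxes

b1 = Bboxes(np.array([[0, 0, 10, 10]], dtype=np.float32))
b2 = Bboxes(np.array([[5, 5, 15, 15]], dtype=np.float32))
merged = Bboxes.concatenate([b1, b2])  # merged.bboxes has shape (2, 4)
```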
@ -148,11 +153,21 @@ class Bboxes:
def __getitem__(self, index) -> 'Bboxes':
"""
Retrieve a specific bounding box or a set of bounding boxes using indexing.
Args:
index: int, slice, or a BoolArray
index (int, slice, or np.ndarray): The index, slice, or boolean array to select
the desired bounding boxes.
Returns:
Bboxes: Create a new :class:`Bboxes` by indexing.
Bboxes: A new Bboxes object containing the selected bounding boxes.
Raises:
AssertionError: If the indexed bounding boxes do not form a 2-dimensional matrix.
Note:
When using boolean indexing, make sure to provide a boolean array with the same
length as the number of bounding boxes.
"""
if isinstance(index, int):
return Bboxes(self.bboxes[index].view(1, -1))
@ -236,11 +251,19 @@ class Instances:
def __getitem__(self, index) -> 'Instances':
"""
Retrieve a specific instance or a set of instances using indexing.
Args:
index: int, slice, or a BoolArray
index (int, slice, or np.ndarray): The index, slice, or boolean array to select
the desired instances.
Returns:
Instances: Create a new :class:`Instances` by indexing.
Instances: A new Instances object containing the selected bounding boxes,
segments, and keypoints if present.
Note:
When using boolean indexing, make sure to provide a boolean array with the same
length as the number of instances.
"""
segments = self.segments[index] if len(self.segments) else self.segments
keypoints = self.keypoints[index] if self.keypoints is not None else None
@ -305,14 +328,20 @@ class Instances:
@classmethod
def concatenate(cls, instances_list: List['Instances'], axis=0) -> 'Instances':
"""
Concatenates a list of Boxes into a single Bboxes
Concatenates a list of Instances objects into a single Instances object.
Arguments:
instances_list (list[Bboxes])
axis
Args:
instances_list (List[Instances]): A list of Instances objects to concatenate.
axis (int, optional): The axis along which the arrays will be concatenated. Defaults to 0.
Returns:
Boxes: the concatenated Boxes
Instances: A new Instances object containing the concatenated bounding boxes,
segments, and keypoints if present.
Note:
The `Instances` objects in the list should have the same properties, such as
the format of the bounding boxes, whether keypoints are present, and if the
coordinates are normalized.
"""
assert isinstance(instances_list, (list, tuple))
if not instances_list:

View File

@ -23,10 +23,16 @@ def box_area(box):
def bbox_ioa(box1, box2, eps=1e-7):
"""Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
box1: np.array of shape(nx4)
box2: np.array of shape(mx4)
returns: np.array of shape(nxm)
"""
Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
Args:
box1 (np.array): A numpy array of shape (n, 4) representing n bounding boxes.
box2 (np.array): A numpy array of shape (m, 4) representing m bounding boxes.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(np.array): A numpy array of shape (n, m) representing the intersection over box2 area.
"""
# Get the coordinates of bounding boxes
@ -46,17 +52,17 @@ def bbox_ioa(box1, box2, eps=1e-7):
def box_iou(box1, box2, eps=1e-7):
"""
Return intersection-over-union (Jaccard index) of boxes.
Calculate intersection-over-union (IoU) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
Arguments:
box1 (Tensor[N, 4])
box2 (Tensor[M, 4])
eps
Args:
box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
(torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
"""
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
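
A hedged completion of the broadcast trick in the inline comment above (mirrors the torchvision approach; a sketch, not necessarily the exact body):

```python
import torch

def box_iou(box1, box2, eps=1e-7):
    # lt/rb corners: (N,1,2) vs (1,M,2) broadcast to (N,M,2)
    (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)  # (N, M)
    area1 = (box1[:, 2:] - box1[:, :2]).prod(1)  # (N,)
    area2 = (box2[:, 2:] - box2[:, :2]).prod(1)  # (M,)
    return inter / (area1[:, None] + area2 - inter + eps)
```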
@ -68,7 +74,22 @@ def box_iou(box1, box2, eps=1e-7):
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
# Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
"""
Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).
Args:
box1 (torch.Tensor): A tensor representing a single bounding box with shape (1, 4).
box2 (torch.Tensor): A tensor representing n bounding boxes with shape (n, 4).
xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
(x1, y1, x2, y2) format. Defaults to True.
GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
"""
# Get the coordinates of bounding boxes
if xywh: # transform from xywh to xyxy
@ -110,10 +131,17 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
def mask_iou(mask1, mask2, eps=1e-7):
"""
mask1: [N, n] m1 means number of gt objects
mask2: [M, n] m2 means number of predicted objects
Note: n means image_w x image_h
Returns: masks iou, [N, M]
Calculate masks IoU.
Args:
mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
product of image width and height.
mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
product of image width and height.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): A tensor of shape (N, M) representing masks IoU.
"""
intersection = torch.matmul(mask1, mask2.t()).clamp(0)
union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection
@ -121,10 +149,18 @@ def mask_iou(mask1, mask2, eps=1e-7):
def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
"""OKS
kpt1: [N, 17, 3], gt
kpt2: [M, 17, 3], pred
area: [N], areas from gt
"""
Calculate Object Keypoint Similarity (OKS).
Args:
kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.
area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.
sigma (list): A list containing 17 values representing keypoint scales.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.
"""
d = (kpt1[:, None, :, 0] - kpt2[..., 0]) ** 2 + (kpt1[:, None, :, 1] - kpt2[..., 1]) ** 2 # (N, M, 17)
sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, )
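
A hedged completion of the OKS computation started above, masking by visible ground-truth keypoints (a sketch; the visibility-mask convention is an assumption):

```python
import torch

def oks(kpt1, kpt2, area, sigma, eps=1e-7):
    d = (kpt1[:, None, :, 0] - kpt2[..., 0]) ** 2 + (kpt1[:, None, :, 1] - kpt2[..., 1]) ** 2  # (N, M, 17)
    sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype)  # (17,)
    kpt_mask = kpt1[..., 2] != 0  # (N, 17), visible ground-truth keypoints
    e = d / (2 * sigma) ** 2 / (area[:, None, None] + eps) / 2  # COCO-style normalized distance
    return (torch.exp(-e) * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)  # (N, M)
```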
@ -171,7 +207,17 @@ class FocalLoss(nn.Module):
class ConfusionMatrix:
# Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
"""
A class for calculating and updating a confusion matrix for object detection and classification tasks.
Attributes:
task (str): The type of task, either 'detect' or 'classify'.
matrix (np.array): The confusion matrix, with dimensions depending on the task.
nc (int): The number of classes.
conf (float): The confidence threshold for detections.
iou_thres (float): The Intersection over Union threshold.
"""
def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'):
self.task = task
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc))
@ -183,12 +229,9 @@ class ConfusionMatrix:
"""
Update confusion matrix for classification task
Arguments:
preds (Array[N, min(nc,5)])
targets (Array[N, 1])
Returns:
None, updates confusion matrix accordingly
Args:
preds (Array[N, min(nc,5)]): Predicted class labels.
targets (Array[N, 1]): Ground truth class labels.
"""
preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
@ -196,15 +239,13 @@ class ConfusionMatrix:
def process_batch(self, detections, labels):
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Update confusion matrix for object detection task.
Arguments:
detections (Array[N, 6]), x1, y1, x2, y2, conf, class
labels (Array[M, 5]), class, x1, y1, x2, y2
Returns:
None, updates confusion matrix accordingly
Args:
detections (Array[N, 6]): Detected bounding boxes and their associated information.
Each row should contain (x1, y1, x2, y2, conf, class).
labels (Array[M, 5]): Ground truth bounding boxes and their associated class labels.
Each row should contain (class, x1, y1, x2, y2).
"""
if detections is None:
gt_classes = labels.int()
@ -254,6 +295,14 @@ class ConfusionMatrix:
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
@plt_settings()
def plot(self, normalize=True, save_dir='', names=()):
"""
Plot the confusion matrix using seaborn and save it to a file.
Args:
normalize (bool): Whether to normalize the confusion matrix.
save_dir (str): Directory where the plot will be saved.
names (tuple): Names of classes, used as labels on the plot.
"""
import seaborn as sn
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
@ -284,6 +333,9 @@ class ConfusionMatrix:
plt.close(fig)
def print(self):
"""
Print the confusion matrix to the console.
"""
for i in range(self.nc + 1):
LOGGER.info(' '.join(map(str, self.matrix[i])))
@ -343,12 +395,17 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
def compute_ap(recall, precision):
""" Compute the average precision, given the recall and precision curves
"""
Compute the average precision (AP) given the recall and precision curves.
Arguments:
recall: The recall curve (list)
precision: The precision curve (list)
recall (list): The recall curve.
precision (list): The precision curve.
Returns:
Average precision, precision curve, recall curve
(float): Average precision.
(np.ndarray): Precision envelope curve.
(np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
"""
# Append sentinel values to beginning and end
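
A minimal sketch of the steps named above: sentinel values, precision envelope, then integration (the 101-point interpolation shown is one common choice, assumed here):

```python
import numpy as np

def compute_ap(recall, precision):
    mrec = np.concatenate(([0.0], recall, [1.0]))         # sentinel values at both ends
    mpre = np.concatenate(([1.0], precision, [0.0]))
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))  # precision envelope (monotone non-increasing)
    x = np.linspace(0, 1, 101)                            # 101-point interpolation (COCO-style)
    ap = np.trapz(np.interp(x, mrec, mpre), x)            # area under the interpolated curve
    return ap, mpre, mrec
```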
@ -488,57 +545,71 @@ class Metric(SimpleClass):
@property
def ap50(self):
"""AP@0.5 of all classes.
"""
Returns the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
Returns:
(nc, ) or [].
(np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
"""
return self.all_ap[:, 0] if len(self.all_ap) else []
@property
def ap(self):
"""AP@0.5:0.95
"""
Returns the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
Returns:
(nc, ) or [].
(np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
"""
return self.all_ap.mean(1) if len(self.all_ap) else []
@property
def mp(self):
"""mean precision of all classes.
"""
Returns the Mean Precision of all classes.
Returns:
float.
(float): The mean precision of all classes.
"""
return self.p.mean() if len(self.p) else 0.0
@property
def mr(self):
"""mean recall of all classes.
"""
Returns the Mean Recall of all classes.
Returns:
float.
(float): The mean recall of all classes.
"""
return self.r.mean() if len(self.r) else 0.0
@property
def map50(self):
"""Mean AP@0.5 of all classes.
"""
Returns the mean Average Precision (mAP) at an IoU threshold of 0.5.
Returns:
float.
(float): The mAP50 at an IoU threshold of 0.5.
"""
return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
@property
def map75(self):
"""Mean AP@0.75 of all classes.
"""
Returns the mean Average Precision (mAP) at an IoU threshold of 0.75.
Returns:
float.
(float): The mAP at an IoU threshold of 0.75.
"""
return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
@property
def map(self):
"""Mean AP@0.5:0.95 of all classes.
"""
Returns the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
Returns:
float.
(float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
"""
return self.all_ap.mean() if len(self.all_ap) else 0.0
@ -566,7 +637,7 @@ class Metric(SimpleClass):
def update(self, results):
"""
Args:
results: tuple(p, r, ap, f1, ap_class)
results (tuple): A tuple of (p, r, ap, f1, ap_class)
"""
self.p, self.r, self.f1, self.all_ap, self.ap_class_index = results

View File

@ -120,10 +120,10 @@ def make_divisible(x, divisor):
Args:
x (int): The number to make divisible.
divisor (int or torch.Tensor): The divisor.
divisor (int) or (torch.Tensor): The divisor.
Returns:
int: The nearest number divisible by the divisor.
(int): The nearest number divisible by the divisor.
"""
if isinstance(divisor, torch.Tensor):
divisor = int(divisor.max()) # to int
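
The rounding itself is one line; a minimal sketch of the rest of the function:

```python
import math

def make_divisible(x, divisor):
    return math.ceil(x / divisor) * divisor  # round up to the nearest multiple of divisor
```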

View File

@ -127,7 +127,7 @@ class Annotator:
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
mcs = (masks_color * inv_alph_masks).sum(0) * 2 # mask color summand shape(n,h,w,3)
mcs = masks_color.max(dim=0).values # shape(n,h,w,3)
im_gpu = im_gpu.flip(dims=[0]) # flip channel
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
@ -140,12 +140,16 @@ class Annotator:
self.fromarray(self.im)
def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True):
"""Plot keypoints.
"""Plot keypoints on the image.
Args:
kpts (tensor): predicted kpts, shape: [17, 3]
shape (tuple): image shape, (h, w)
steps (int): keypoints step
radius (int): size of drawing points
kpts (tensor): Predicted keypoints with shape [17, 3]. Each keypoint has (x, y, confidence).
shape (tuple): Image shape as a tuple (h, w), where h is the height and w is the width.
radius (int, optional): Radius of the drawn keypoints. Default is 5.
kpt_line (bool, optional): If True, the function will draw lines connecting keypoints
for human pose. Default is True.
Note: `kpt_line=True` currently only supports human pose plotting.
"""
if self.pil:
# convert to numpy first

View File

@ -54,6 +54,19 @@ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
class TaskAlignedAssigner(nn.Module):
"""
A task-aligned assigner for object detection.
This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric,
which combines both classification and localization information.
Attributes:
topk (int): The number of top candidates to consider.
num_classes (int): The number of object classes.
alpha (float): The alpha parameter for the classification component of the task-aligned metric.
beta (float): The beta parameter for the localization component of the task-aligned metric.
eps (float): A small value to prevent division by zero.
"""
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
super().__init__()
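
The task-aligned metric named in the attributes combines both signals as t = s**alpha * u**beta, with s the predicted classification score and u the IoU; a one-line sketch of the idea:

```python
import torch

def align_metric(scores, ious, alpha=1.0, beta=6.0):
    return scores.pow(alpha) * ious.pow(beta)  # task-aligned metric t = s**alpha * u**beta
```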
@ -66,8 +79,9 @@ class TaskAlignedAssigner(nn.Module):
@torch.no_grad()
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
"""This code referenced to
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
"""
Compute the task-aligned assignment.
Reference https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
Args:
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
@ -76,11 +90,13 @@ class TaskAlignedAssigner(nn.Module):
gt_labels (Tensor): shape(bs, n_max_boxes, 1)
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
mask_gt (Tensor): shape(bs, n_max_boxes, 1)
Returns:
target_labels (Tensor): shape(bs, num_total_anchors)
target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
fg_mask (Tensor): shape(bs, num_total_anchors)
target_gt_idx (Tensor): shape(bs, num_total_anchors)
"""
self.bs = pd_scores.size(0)
self.n_max_boxes = gt_bboxes.size(1)
@ -142,9 +158,19 @@ class TaskAlignedAssigner(nn.Module):
def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
"""
Select the top-k candidates based on the given metrics.
Args:
metrics: (b, max_num_obj, h*w).
topk_mask: (b, max_num_obj, topk) or None
metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size,
max_num_obj is the maximum number of objects, and h*w represents the
total number of anchor points.
largest (bool): If True, select the largest values; otherwise, select the smallest values.
topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where
topk is the number of top candidates to consider. If not provided,
the top-k values are automatically computed based on the given metrics.
Returns:
(Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
"""
num_anchors = metrics.shape[-1] # h*w
@ -165,22 +191,38 @@ class TaskAlignedAssigner(nn.Module):
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
"""
Compute target labels, target bounding boxes, and target scores for the positive anchor points.
Args:
gt_labels: (b, max_num_obj, 1)
gt_bboxes: (b, max_num_obj, 4)
target_gt_idx: (b, h*w)
fg_mask: (b, h*w)
gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
batch size and max_num_obj is the maximum number of objects.
gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
anchor points, with shape (b, h*w), where h*w is the total
number of anchor points.
fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
(foreground) anchor points.
Returns:
(Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
- target_labels (Tensor): Shape (b, h*w), containing the target labels for
positive anchor points.
- target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes
for positive anchor points.
- target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
for positive anchor points, where num_classes is the number
of object classes.
"""
# assigned target labels, (b, 1)
# Assigned target labels, (b, 1)
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)
target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w)
# assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
# Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
# assigned target scores
# Assigned target scores
target_labels.clamp(0)
target_scores = F.one_hot(target_labels, self.num_classes) # (b, h*w, 80)
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)

View File

@ -427,7 +427,7 @@ class EarlyStopping:
fitness (float): Fitness value of current epoch
Returns:
bool: True if training should stop, False otherwise
(bool): True if training should stop, False otherwise
"""
if fitness is None: # check if fitness=None (happens when val=False)
return False
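
A simplified sketch of patience-based stopping consistent with the docstring above (attribute names are assumptions):

```python
class EarlyStopping:
    def __init__(self, patience=50):
        self.best_fitness = 0.0
        self.best_epoch = 0
        self.patience = patience

    def __call__(self, epoch, fitness):
        if fitness is None:  # happens when val=False
            return False
        if fitness >= self.best_fitness:  # >= tolerates flat early epochs
            self.best_epoch, self.best_fitness = epoch, fitness
        return (epoch - self.best_epoch) >= self.patience  # True -> stop training
```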

View File

@ -101,18 +101,6 @@ class ClassificationTrainer(BaseTrainer):
loss_items = loss.detach()
return loss, loss_items
# def label_loss_items(self, loss_items=None, prefix="train"):
# """
# Returns a loss dict with labelled training loss items tensor
# """
# # Not needed for classification but necessary for segmentation & detection
# keys = [f"{prefix}/{x}" for x in self.loss_names]
# if loss_items is not None:
# loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
# return dict(zip(keys, loss_items))
# else:
# return keys
def label_loss_items(self, loss_items=None, prefix='train'):
"""
Returns a loss dict with labelled training loss items tensor
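
For reference, the commented-out block deleted above corresponds to this logic (a sketch reconstructed from that comment, not the final method body):

```python
def label_loss_items(self, loss_items=None, prefix='train'):
    keys = [f'{prefix}/{x}' for x in self.loss_names]
    if loss_items is None:
        return keys  # keys only, e.g. for CSV headers
    return dict(zip(keys, (round(float(x), 5) for x in loss_items)))  # labelled loss dict
```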