ultralytics 8.0.79 expand Docs reference section (#2053)

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Fri3dChicken <87434761+AmoghDhaliwal@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-04-16 12:28:12 +02:00
committed by GitHub
parent 47bd8b433b
commit 31db8ed163
106 changed files with 2570 additions and 529 deletions

View File

@ -262,13 +262,15 @@ class RandomPerspective:
return img, M, s
def apply_bboxes(self, bboxes, M):
"""apply affine to bboxes only.
"""
Apply affine to bboxes only.
Args:
bboxes(ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
M(ndarray): affine matrix.
bboxes (ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
M (ndarray): affine matrix.
Returns:
new_bboxes(ndarray): bboxes after affine, [num_bboxes, 4].
new_bboxes (ndarray): bboxes after affine, [num_bboxes, 4].
"""
n = len(bboxes)
if n == 0:
@ -285,14 +287,16 @@ class RandomPerspective:
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
def apply_segments(self, segments, M):
"""apply affine to segments and generate new bboxes from segments.
"""
Apply affine to segments and generate new bboxes from segments.
Args:
segments(ndarray): list of segments, [num_samples, 500, 2].
M(ndarray): affine matrix.
segments (ndarray): list of segments, [num_samples, 500, 2].
M (ndarray): affine matrix.
Returns:
new_segments(ndarray): list of segments after affine, [num_samples, 500, 2].
new_bboxes(ndarray): bboxes after affine, [N, 4].
new_segments (ndarray): list of segments after affine, [num_samples, 500, 2].
new_bboxes (ndarray): bboxes after affine, [N, 4].
"""
n, num = segments.shape[:2]
if n == 0:
@ -308,13 +312,15 @@ class RandomPerspective:
return bboxes, segments
def apply_keypoints(self, keypoints, M):
"""apply affine to keypoints.
"""
Apply affine to keypoints.
Args:
keypoints(ndarray): keypoints, [N, 17, 3].
M(ndarray): affine matrix.
keypoints (ndarray): keypoints, [N, 17, 3].
M (ndarray): affine matrix.
Returns:
new_keypoints(ndarray): keypoints after affine, [N, 17, 3].
new_keypoints (ndarray): keypoints after affine, [N, 17, 3].
"""
n, nkpt = keypoints.shape[:2]
if n == 0:
@ -333,7 +339,7 @@ class RandomPerspective:
Affine images and targets.
Args:
labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
labels (dict): a dict of `bboxes`, `segments`, `keypoints`.
"""
if self.pre_transform and 'mosaic_border' not in labels:
labels = self.pre_transform(labels)

View File

@ -18,11 +18,30 @@ from .utils import HELP_URL, IMG_FORMATS
class BaseDataset(Dataset):
"""Base Dataset.
"""
Base dataset class for loading and processing image data.
Args:
img_path (str): image path.
pipeline (dict): a dict of image transforms.
label_path (str): label path, this can also be an ann_file or other custom label path.
img_path (str): Image path.
imgsz (int): Target image size for resizing. Default is 640.
cache (bool): Cache images in memory or on disk for faster loading. Default is False.
augment (bool): Apply data augmentation. Default is True.
hyp (dict): Dictionary of hyperparameters for data augmentation. Default is None.
prefix (str): Prefix for file paths. Default is an empty string.
rect (bool): Enable rectangular training. Default is False.
batch_size (int): Batch size for rectangular training. Default is None.
stride (int): Stride for rectangular training. Default is 32.
pad (float): Padding for rectangular training. Default is 0.5.
single_cls (bool): Use a single class for all labels. Default is False.
classes (list): List of included classes. Default is None.
Attributes:
im_files (list): List of image file paths.
labels (list): List of label data dictionaries.
ni (int): Number of images in the dataset.
ims (list): List of loaded images.
npy_files (list): List of numpy file paths.
transforms (callable): Image transformation function.
"""
def __init__(self,

View File

@ -21,10 +21,7 @@ from .utils import PIN_MEMORY
class InfiniteDataLoader(dataloader.DataLoader):
"""Dataloader that reuses workers
Uses same syntax as vanilla DataLoader
"""
"""Dataloader that reuses workers. Uses same syntax as vanilla DataLoader."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -40,10 +37,11 @@ class InfiniteDataLoader(dataloader.DataLoader):
class _RepeatSampler:
"""Sampler that repeats forever
"""
Sampler that repeats forever.
Args:
sampler (Sampler)
sampler (Dataset.sampler): The sampler to repeat.
"""
def __init__(self, sampler):
@ -173,7 +171,7 @@ def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1,
auto (bool, optional): Automatically apply pre-processing. Default is True.
Returns:
dataset: A dataset object for the specified input source.
dataset (Dataset): A dataset object for the specified input source.
"""
source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)

View File

@ -178,7 +178,7 @@ class _RepeatSampler:
""" Sampler that repeats forever
Args:
sampler (Sampler)
sampler (Dataset.sampler): The sampler to repeat.
"""
def __init__(self, sampler):

View File

@ -17,31 +17,31 @@ from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_lab
class YOLODataset(BaseDataset):
cache_version = '1.0.2' # dataset labels *.cache version, >= 1.0.0 for YOLOv8
rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
"""
Dataset class for loading images object detection and/or segmentation labels in YOLO format.
Dataset class for loading object detection and/or segmentation labels in YOLO format.
Args:
img_path (str): path to the folder containing images.
imgsz (int): image size (default: 640).
cache (bool): if True, a cache file of the labels is created to speed up future creation of dataset instances
(default: False).
augment (bool): if True, data augmentation is applied (default: True).
hyp (dict): hyperparameters to apply data augmentation (default: None).
prefix (str): prefix to print in log messages (default: '').
rect (bool): if True, rectangular training is used (default: False).
batch_size (int): size of batches (default: None).
stride (int): stride (default: 32).
pad (float): padding (default: 0.0).
single_cls (bool): if True, single class training is used (default: False).
use_segments (bool): if True, segmentation masks are used as labels (default: False).
use_keypoints (bool): if True, keypoints are used as labels (default: False).
names (dict): A dictionary of class names. (default: None).
img_path (str): Path to the folder containing images.
imgsz (int, optional): Image size. Defaults to 640.
cache (bool, optional): Cache images to RAM or disk during training. Defaults to False.
augment (bool, optional): If True, data augmentation is applied. Defaults to True.
hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None.
prefix (str, optional): Prefix to print in log messages. Defaults to ''.
rect (bool, optional): If True, rectangular training is used. Defaults to False.
batch_size (int, optional): Size of batches. Defaults to None.
stride (int, optional): Stride. Defaults to 32.
pad (float, optional): Padding. Defaults to 0.0.
single_cls (bool, optional): If True, single class training is used. Defaults to False.
use_segments (bool, optional): If True, segmentation masks are used as labels. Defaults to False.
use_keypoints (bool, optional): If True, keypoints are used as labels. Defaults to False.
data (dict, optional): A dataset YAML dictionary. Defaults to None.
classes (list, optional): List of included classes. Defaults to None.
Returns:
A PyTorch dataset object that can be used for training an object detection or segmentation model.
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
cache_version = '1.0.2' # dataset labels *.cache version, >= 1.0.0 for YOLOv8
rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
def __init__(self,
img_path,

View File

@ -7,21 +7,36 @@ from .augment import LetterBox
class MixAndRectDataset:
"""A wrapper of multiple images mixed dataset.
"""
A dataset class that applies mosaic and mixup transformations as well as rectangular training.
Args:
dataset (:obj:`BaseDataset`): The dataset to be mixed.
transforms (Sequence[dict]): config dict to be composed.
Attributes:
dataset: The base dataset.
imgsz: The size of the images in the dataset.
"""
def __init__(self, dataset):
"""
Args:
dataset (BaseDataset): The base dataset to apply transformations to.
"""
self.dataset = dataset
self.imgsz = dataset.imgsz
def __len__(self):
"""Returns the number of items in the dataset."""
return len(self.dataset)
def __getitem__(self, index):
"""
Applies mosaic, mixup and rectangular training transformations to an item in the dataset.
Args:
index (int): Index of the item in the dataset.
Returns:
(dict): A dictionary containing the transformed item data.
"""
labels = deepcopy(self.dataset[index])
for transform in self.dataset.transforms.tolist():
# mosaic and mixup

View File

@ -270,9 +270,8 @@ def check_cls_dataset(dataset: str):
"""
Check a classification dataset such as Imagenet.
Copy code
This function takes a `dataset` name as input and returns a dictionary containing information about the dataset.
If the dataset is not found, it attempts to download the dataset from the internet and save it to the local file system.
If the dataset is not found, it attempts to download the dataset from the internet and save it locally.
Args:
dataset (str): Name of the dataset.
@ -306,7 +305,8 @@ def check_cls_dataset(dataset: str):
class HUBDatasetStats():
""" Class for generating HUB dataset JSON and `-hub` dataset directory
"""
Class for generating HUB dataset JSON and `-hub` dataset directory.
Arguments
path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
@ -427,9 +427,6 @@ def compress_one_image(f, f_new=None, max_dim=1920, quality=50):
max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels.
quality (int, optional): The image compression quality as a percentage. Default is 50%.
Returns:
None
Usage:
from pathlib import Path
from ultralytics.yolo.data.utils import compress_one_image
@ -459,9 +456,6 @@ def delete_dsstore(path):
Args:
path (str): The directory path where the ".DS_Store" files should be deleted.
Returns:
None
Usage:
from ultralytics.yolo.data.utils import delete_dsstore
delete_dsstore('/Users/glennjocher/Downloads/dataset')
@ -478,15 +472,13 @@ def delete_dsstore(path):
def zip_directory(dir, use_zipfile_library=True):
"""Zips a directory and saves the archive to the specified output path.
"""
Zips a directory and saves the archive to the specified output path.
Args:
dir (str): The path to the directory to be zipped.
use_zipfile_library (bool): Whether to use zipfile library or shutil for zipping.
Returns:
None
Usage:
from ultralytics.yolo.data.utils import zip_directory
zip_directory('/Users/glennjocher/Downloads/playground')