ultralytics 8.0.81
single-line docstring updates (#2061)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -393,6 +393,7 @@ def entrypoint(debug=''):
|
||||
|
||||
# Special modes --------------------------------------------------------------------------------------------------------
|
||||
def copy_default_cfg():
|
||||
"""Copy and create a new default configuration file with '_copy' appended to its name."""
|
||||
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
||||
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
||||
LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'
|
||||
|
@ -26,15 +26,19 @@ class BaseTransform:
|
||||
pass
|
||||
|
||||
def apply_image(self, labels):
|
||||
"""Applies image transformation to labels."""
|
||||
pass
|
||||
|
||||
def apply_instances(self, labels):
|
||||
"""Applies transformations to the object instances in 'labels'."""
|
||||
pass
|
||||
|
||||
def apply_semantic(self, labels):
|
||||
"""Applies transformations to the semantic segmentation masks in 'labels'."""
|
||||
pass
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Applies label transformations to an image, instances and semantic masks."""
|
||||
self.apply_image(labels)
|
||||
self.apply_instances(labels)
|
||||
self.apply_semantic(labels)
|
||||
@ -43,20 +47,25 @@ class BaseTransform:
|
||||
class Compose:
|
||||
|
||||
def __init__(self, transforms):
|
||||
"""Initializes the Compose object with a list of transforms."""
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, data):
|
||||
"""Applies a series of transformations to input data."""
|
||||
for t in self.transforms:
|
||||
data = t(data)
|
||||
return data
|
||||
|
||||
def append(self, transform):
|
||||
"""Appends a new transform to the existing list of transforms."""
|
||||
self.transforms.append(transform)
|
||||
|
||||
def tolist(self):
|
||||
"""Returns the underlying list of transforms (the same object, not a copy)."""
|
||||
return self.transforms
|
||||
|
||||
def __repr__(self):
|
||||
"""Return string representation of object."""
|
||||
format_string = f'{self.__class__.__name__}('
|
||||
for t in self.transforms:
|
||||
format_string += '\n'
|
||||
@ -74,6 +83,7 @@ class BaseMixTransform:
|
||||
self.p = p
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Applies pre-processing transforms and mixup/mosaic transforms to labels data."""
|
||||
if random.uniform(0, 1) > self.p:
|
||||
return labels
|
||||
|
||||
@ -96,9 +106,11 @@ class BaseMixTransform:
|
||||
return labels
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
"""Applies MixUp or Mosaic augmentation to the label dictionary."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_indexes(self):
|
||||
"""Returns the dataset index or indexes of the images to mix in; implemented by subclasses."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -111,6 +123,7 @@ class Mosaic(BaseMixTransform):
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
|
||||
"""Initializes the object with a dataset, image size, probability, and border."""
|
||||
assert 0 <= p <= 1.0, 'The probability should be in range [0, 1]. ' f'got {p}.'
|
||||
super().__init__(dataset=dataset, p=p)
|
||||
self.dataset = dataset
|
||||
@ -118,9 +131,11 @@ class Mosaic(BaseMixTransform):
|
||||
self.border = border
|
||||
|
||||
def get_indexes(self):
|
||||
"""Return a list of 3 random indexes from the dataset."""
|
||||
return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
"""Apply mosaic augmentation to the input image and labels."""
|
||||
mosaic_labels = []
|
||||
assert labels.get('rect_shape', None) is None, 'rect and mosaic is exclusive.'
|
||||
assert len(labels.get('mix_labels', [])) > 0, 'There are no other images for mosaic augment.'
|
||||
@ -166,6 +181,7 @@ class Mosaic(BaseMixTransform):
|
||||
return labels
|
||||
|
||||
def _cat_labels(self, mosaic_labels):
|
||||
"""Return labels with mosaic border instances clipped."""
|
||||
if len(mosaic_labels) == 0:
|
||||
return {}
|
||||
cls = []
|
||||
@ -190,6 +206,7 @@ class MixUp(BaseMixTransform):
|
||||
super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
|
||||
|
||||
def get_indexes(self):
|
||||
"""Get a random index from the dataset."""
|
||||
return random.randint(0, len(self.dataset) - 1)
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
@ -400,6 +417,7 @@ class RandomHSV:
|
||||
self.vgain = vgain
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Applies random hue, saturation and value (HSV) gains to the image in 'labels'."""
|
||||
img = labels['img']
|
||||
if self.hgain or self.sgain or self.vgain:
|
||||
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
|
||||
@ -427,6 +445,7 @@ class RandomFlip:
|
||||
self.flip_idx = flip_idx
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Applies a random flip to the image and its instances."""
|
||||
img = labels['img']
|
||||
instances = labels.pop('instances')
|
||||
instances.convert_bbox(format='xywh')
|
||||
@ -453,6 +472,7 @@ class LetterBox:
|
||||
"""Resize image and padding for detection, instance segmentation, pose."""
|
||||
|
||||
def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
|
||||
"""Initialize LetterBox object with specific parameters."""
|
||||
self.new_shape = new_shape
|
||||
self.auto = auto
|
||||
self.scaleFill = scaleFill
|
||||
@ -460,6 +480,7 @@ class LetterBox:
|
||||
self.stride = stride
|
||||
|
||||
def __call__(self, labels=None, image=None):
|
||||
"""Return updated labels and image with added border."""
|
||||
if labels is None:
|
||||
labels = {}
|
||||
img = labels.get('img') if image is None else image
|
||||
@ -556,6 +577,7 @@ class CopyPaste:
|
||||
class Albumentations:
|
||||
# YOLOv8 Albumentations class (optional, only used if package is installed)
|
||||
def __init__(self, p=1.0):
|
||||
"""Initialize the transform object for YOLO bbox formatted params."""
|
||||
self.p = p
|
||||
self.transform = None
|
||||
prefix = colorstr('albumentations: ')
|
||||
@ -581,6 +603,7 @@ class Albumentations:
|
||||
LOGGER.info(f'{prefix}{e}')
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Applies the Albumentations transform to the image and labels."""
|
||||
im = labels['img']
|
||||
cls = labels['cls']
|
||||
if len(cls):
|
||||
@ -618,6 +641,7 @@ class Format:
|
||||
self.batch_idx = batch_idx # keep the batch indexes
|
||||
|
||||
def __call__(self, labels):
|
||||
"""Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'."""
|
||||
img = labels.pop('img')
|
||||
h, w = img.shape[:2]
|
||||
cls = labels.pop('cls')
|
||||
@ -647,6 +671,7 @@ class Format:
|
||||
return labels
|
||||
|
||||
def _format_img(self, img):
|
||||
"""Format the image for YOLOv5 from Numpy array to PyTorch tensor."""
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
|
||||
@ -668,6 +693,7 @@ class Format:
|
||||
|
||||
|
||||
def v8_transforms(dataset, imgsz, hyp):
|
||||
"""Builds the composed augmentation transforms, beginning with Mosaic and CopyPaste, for YOLOv8 training."""
|
||||
pre_transform = Compose([
|
||||
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
|
||||
CopyPaste(p=hyp.copy_paste),
|
||||
@ -749,6 +775,7 @@ def classify_albumentations(
|
||||
class ClassifyLetterBox:
|
||||
# YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
||||
def __init__(self, size=(640, 640), auto=False, stride=32):
|
||||
"""Initializes ClassifyLetterBox from the target size, 'auto' flag and stride."""
|
||||
super().__init__()
|
||||
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||
self.auto = auto # pass max size integer, automatically solve for short side using stride
|
||||
@ -768,6 +795,7 @@ class ClassifyLetterBox:
|
||||
class CenterCrop:
|
||||
# YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
|
||||
def __init__(self, size=640):
|
||||
"""Initializes CenterCrop with the target crop height and width."""
|
||||
super().__init__()
|
||||
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||
|
||||
@ -781,6 +809,7 @@ class CenterCrop:
|
||||
class ToTensor:
|
||||
# YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
||||
def __init__(self, half=False):
|
||||
"""Initialize YOLOv8 ToTensor object with optional half-precision support."""
|
||||
super().__init__()
|
||||
self.half = half
|
||||
|
||||
|
@ -170,6 +170,7 @@ class BaseDataset(Dataset):
|
||||
np.save(f.as_posix(), cv2.imread(self.im_files[i]))
|
||||
|
||||
def set_rectangle(self):
|
||||
"""Sets the shape of bounding boxes for YOLO detections as rectangles."""
|
||||
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
|
||||
nb = bi[-1] + 1 # number of batches
|
||||
|
||||
@ -194,9 +195,11 @@ class BaseDataset(Dataset):
|
||||
self.batch = bi # batch index of image
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Returns transformed label information for given index."""
|
||||
return self.transforms(self.get_label_info(index))
|
||||
|
||||
def get_label_info(self, index):
|
||||
"""Get and return label information from the dataset."""
|
||||
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
|
||||
label.pop('shape', None) # shape is for rect, remove it
|
||||
label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
|
||||
@ -208,6 +211,7 @@ class BaseDataset(Dataset):
|
||||
return label
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the length of the labels list for the dataset."""
|
||||
return len(self.labels)
|
||||
|
||||
def update_labels_info(self, label):
|
||||
|
@ -24,14 +24,17 @@ class InfiniteDataLoader(dataloader.DataLoader):
|
||||
"""Dataloader that reuses workers. Uses same syntax as vanilla DataLoader."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Dataloader that infinitely recycles workers, inherits from DataLoader."""
|
||||
super().__init__(*args, **kwargs)
|
||||
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
|
||||
self.iterator = super().__iter__()
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the length of the batch sampler's sampler."""
|
||||
return len(self.batch_sampler.sampler)
|
||||
|
||||
def __iter__(self):
|
||||
"""Yields len(self) batches from the persistent iterator created in __init__."""
|
||||
for _ in range(len(self)):
|
||||
yield next(self.iterator)
|
||||
|
||||
@ -45,9 +48,11 @@ class _RepeatSampler:
|
||||
"""
|
||||
|
||||
def __init__(self, sampler):
|
||||
"""Initializes an object that repeats a given sampler indefinitely."""
|
||||
self.sampler = sampler
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterates over the 'sampler' and yields its contents."""
|
||||
while True:
|
||||
yield from iter(self.sampler)
|
||||
|
||||
@ -60,6 +65,7 @@ def seed_worker(worker_id): # noqa
|
||||
|
||||
|
||||
def build_dataloader(cfg, batch, img_path, data_info, stride=32, rect=False, rank=-1, mode='train'):
|
||||
"""Return an InfiniteDataLoader or DataLoader for training or validation set."""
|
||||
assert mode in ['train', 'val']
|
||||
shuffle = mode == 'train'
|
||||
if cfg.rect and shuffle:
|
||||
@ -134,6 +140,7 @@ def build_classification_dataloader(path,
|
||||
|
||||
|
||||
def check_source(source):
|
||||
"""Check source type and return corresponding flag values."""
|
||||
webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
|
||||
if isinstance(source, (str, int, Path)): # int for local usb camera
|
||||
source = str(source)
|
||||
|
@ -32,6 +32,7 @@ class SourceTypes:
|
||||
class LoadStreams:
|
||||
# YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
|
||||
def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
||||
"""Initialize instance variables and check for consistent input stream shapes."""
|
||||
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
|
||||
self.mode = 'stream'
|
||||
self.imgsz = imgsz
|
||||
@ -97,10 +98,12 @@ class LoadStreams:
|
||||
time.sleep(0.0) # wait time
|
||||
|
||||
def __iter__(self):
|
||||
"""Resets the iteration counter and returns the iterator object (self)."""
|
||||
self.count = -1
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""Returns source paths, transformed and original images for processing."""
|
||||
self.count += 1
|
||||
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
|
||||
cv2.destroyAllWindows()
|
||||
@ -117,6 +120,7 @@ class LoadStreams:
|
||||
return self.sources, im, im0, None, ''
|
||||
|
||||
def __len__(self):
|
||||
"""Return the length of the sources object."""
|
||||
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
|
||||
|
||||
|
||||
@ -153,6 +157,7 @@ class LoadScreenshots:
|
||||
self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator of the object."""
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
@ -173,6 +178,7 @@ class LoadScreenshots:
|
||||
class LoadImages:
|
||||
# YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`
|
||||
def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
||||
"""Initialize the Dataloader and raise FileNotFoundError if file not found."""
|
||||
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
|
||||
path = Path(path).read_text().rsplit()
|
||||
files = []
|
||||
@ -211,10 +217,12 @@ class LoadImages:
|
||||
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}')
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator object for VideoStream or ImageFolder."""
|
||||
self.count = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""Return next image, path and metadata from dataset."""
|
||||
if self.count == self.nf:
|
||||
raise StopIteration
|
||||
path = self.files[self.count]
|
||||
@ -276,12 +284,14 @@ class LoadImages:
|
||||
return im
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of files in the object."""
|
||||
return self.nf # number of files
|
||||
|
||||
|
||||
class LoadPilAndNumpy:
|
||||
|
||||
def __init__(self, im0, imgsz=640, stride=32, auto=True, transforms=None):
|
||||
"""Initialize PIL and Numpy Dataloader."""
|
||||
if not isinstance(im0, list):
|
||||
im0 = [im0]
|
||||
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
|
||||
@ -296,6 +306,7 @@ class LoadPilAndNumpy:
|
||||
|
||||
@staticmethod
|
||||
def _single_check(im):
|
||||
"""Validate and format an image to numpy array."""
|
||||
assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
|
||||
if isinstance(im, Image.Image):
|
||||
if im.mode != 'RGB':
|
||||
@ -305,6 +316,7 @@ class LoadPilAndNumpy:
|
||||
return im
|
||||
|
||||
def _single_preprocess(self, im, auto):
|
||||
"""Preprocesses a single image for inference."""
|
||||
if self.transforms:
|
||||
im = self.transforms(im) # transforms
|
||||
else:
|
||||
@ -314,9 +326,11 @@ class LoadPilAndNumpy:
|
||||
return im
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the length of the 'im0' attribute."""
|
||||
return len(self.im0)
|
||||
|
||||
def __next__(self):
|
||||
"""Returns batch paths, processed images, original images, None, ''."""
|
||||
if self.count == 1: # loop only once as it's batch inference
|
||||
raise StopIteration
|
||||
auto = all(x.shape == self.im0[0].shape for x in self.im0) and self.auto
|
||||
@ -326,6 +340,7 @@ class LoadPilAndNumpy:
|
||||
return self.paths, im, self.im0, None, ''
|
||||
|
||||
def __iter__(self):
|
||||
"""Enables iteration for class LoadPilAndNumpy."""
|
||||
self.count = 0
|
||||
return self
|
||||
|
||||
@ -338,16 +353,19 @@ class LoadTensor:
|
||||
self.mode = 'image'
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator object."""
|
||||
self.count = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""Return next item in the iterator."""
|
||||
if self.count == 1:
|
||||
raise StopIteration
|
||||
self.count += 1
|
||||
return None, self.im0, self.im0, None, '' # self.paths, im, self.im0, None, ''
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the batch size."""
|
||||
return self.bs
|
||||
|
||||
|
||||
|
@ -24,6 +24,7 @@ IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
|
||||
class Albumentations:
|
||||
# YOLOv5 Albumentations class (optional, only used if package is installed)
|
||||
def __init__(self, size=640):
|
||||
"""Instantiate object with image augmentations for YOLOv5."""
|
||||
self.transform = None
|
||||
prefix = colorstr('albumentations: ')
|
||||
try:
|
||||
@ -48,6 +49,7 @@ class Albumentations:
|
||||
LOGGER.info(f'{prefix}{e}')
|
||||
|
||||
def __call__(self, im, labels, p=1.0):
|
||||
"""Transforms input image and labels with probability 'p'."""
|
||||
if self.transform and random.random() < p:
|
||||
new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed
|
||||
im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
|
||||
@ -111,7 +113,7 @@ def replicate(im, labels):
|
||||
|
||||
|
||||
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
|
||||
# Resize and pad image while meeting stride-multiple constraints
|
||||
"""Resize and pad image while meeting stride-multiple constraints."""
|
||||
shape = im.shape[:2] # current shape [height, width]
|
||||
if isinstance(new_shape, int):
|
||||
new_shape = (new_shape, new_shape)
|
||||
@ -359,6 +361,7 @@ def classify_transforms(size=224):
|
||||
class LetterBox:
|
||||
# YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
||||
def __init__(self, size=(640, 640), auto=False, stride=32):
|
||||
"""Initializes LetterBox from the target size, 'auto' flag and stride for YOLOv5 preprocessing."""
|
||||
super().__init__()
|
||||
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||
self.auto = auto # pass max size integer, automatically solve for short side using stride
|
||||
@ -378,6 +381,7 @@ class LetterBox:
|
||||
class CenterCrop:
|
||||
# YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
|
||||
def __init__(self, size=640):
|
||||
"""Initializes CenterCrop with the target crop height and width for YOLOv5 preprocessing."""
|
||||
super().__init__()
|
||||
self.h, self.w = (size, size) if isinstance(size, int) else size
|
||||
|
||||
@ -391,6 +395,7 @@ class CenterCrop:
|
||||
class ToTensor:
|
||||
# YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
||||
def __init__(self, half=False):
|
||||
"""Initialize ToTensor class for YOLOv5 image preprocessing."""
|
||||
super().__init__()
|
||||
self.half = half
|
||||
|
||||
|
@ -162,14 +162,17 @@ class InfiniteDataLoader(dataloader.DataLoader):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initializes a DataLoader that reuses workers by wrapping its batch sampler in _RepeatSampler."""
|
||||
super().__init__(*args, **kwargs)
|
||||
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
|
||||
self.iterator = super().__iter__()
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the length of batch_sampler's sampler."""
|
||||
return len(self.batch_sampler.sampler)
|
||||
|
||||
def __iter__(self):
|
||||
"""Yields len(self) batches from the persistent iterator created in __init__."""
|
||||
for _ in range(len(self)):
|
||||
yield next(self.iterator)
|
||||
|
||||
@ -182,9 +185,11 @@ class _RepeatSampler:
|
||||
"""
|
||||
|
||||
def __init__(self, sampler):
|
||||
"""Sampler that repeats dataset samples infinitely."""
|
||||
self.sampler = sampler
|
||||
|
||||
def __iter__(self):
|
||||
"""Infinite loop iterating over a given sampler."""
|
||||
while True:
|
||||
yield from iter(self.sampler)
|
||||
|
||||
@ -221,6 +226,7 @@ class LoadScreenshots:
|
||||
self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns the iterator object (self)."""
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
@ -241,6 +247,7 @@ class LoadScreenshots:
|
||||
class LoadImages:
|
||||
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
|
||||
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
||||
"""Initialize instance variables and check for valid input."""
|
||||
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
|
||||
path = Path(path).read_text().rsplit()
|
||||
files = []
|
||||
@ -276,10 +283,12 @@ class LoadImages:
|
||||
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator object for iterating over images or videos found in a directory."""
|
||||
self.count = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""Iterator's next item, performs transformation on image and returns path, transformed image, original image, capture and size."""
|
||||
if self.count == self.nf:
|
||||
raise StopIteration
|
||||
path = self.files[self.count]
|
||||
@ -338,12 +347,14 @@ class LoadImages:
|
||||
return im
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of files in the class instance."""
|
||||
return self.nf # number of files
|
||||
|
||||
|
||||
class LoadStreams:
|
||||
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
|
||||
def __init__(self, sources='file.streams', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
||||
"""Initialize the stream loader with optional transforms and check input shapes."""
|
||||
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
|
||||
self.mode = 'stream'
|
||||
self.img_size = img_size
|
||||
@ -404,10 +415,12 @@ class LoadStreams:
|
||||
time.sleep(0.0) # wait time
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterator that returns the class instance."""
|
||||
self.count = -1
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
"""Return a tuple containing transformed and resized image data."""
|
||||
self.count += 1
|
||||
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
|
||||
cv2.destroyAllWindows()
|
||||
@ -424,6 +437,7 @@ class LoadStreams:
|
||||
return self.sources, im, im0, None, ''
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of sources as the length of the object."""
|
||||
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
|
||||
|
||||
|
||||
@ -607,6 +621,7 @@ class LoadImagesAndLabels(Dataset):
|
||||
return cache
|
||||
|
||||
def cache_labels(self, path=Path('./labels.cache'), prefix=''):
|
||||
"""Cache labels and save as numpy file for next time."""
|
||||
# Cache dataset labels, check images and read shapes
|
||||
if path.exists():
|
||||
path.unlink() # remove *.cache file if exists
|
||||
@ -646,9 +661,11 @@ class LoadImagesAndLabels(Dataset):
|
||||
return x
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the length of 'im_files' attribute."""
|
||||
return len(self.im_files)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Get a sample and its corresponding label, filename and shape from the dataset."""
|
||||
index = self.indices[index] # linear, shuffled, or image_weights
|
||||
|
||||
hyp = self.hyp
|
||||
@ -1039,6 +1056,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||
"""
|
||||
|
||||
def __init__(self, root, augment, imgsz, cache=False):
|
||||
"""Initialize YOLO dataset with root, augmentation, image size, and cache parameters."""
|
||||
super().__init__(root=root)
|
||||
self.torch_transforms = classify_transforms(imgsz)
|
||||
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
|
||||
@ -1047,6 +1065,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
|
||||
|
||||
def __getitem__(self, i):
|
||||
"""Retrieves the sample at index 'i', reading its image (cached in RAM when enabled)."""
|
||||
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
||||
if self.cache_ram and im is None:
|
||||
im = self.samples[i][3] = cv2.imread(f)
|
||||
|
@ -127,6 +127,7 @@ class YOLODataset(BaseDataset):
|
||||
return x
|
||||
|
||||
def get_labels(self):
|
||||
"""Returns dictionary of labels for YOLO training."""
|
||||
self.label_files = img2label_paths(self.im_files)
|
||||
cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
|
||||
try:
|
||||
@ -170,6 +171,7 @@ class YOLODataset(BaseDataset):
|
||||
|
||||
# TODO: use hyp config to set all these augmentations
|
||||
def build_transforms(self, hyp=None):
|
||||
"""Builds and appends transforms to the list."""
|
||||
if self.augment:
|
||||
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
|
||||
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
|
||||
@ -187,6 +189,7 @@ class YOLODataset(BaseDataset):
|
||||
return transforms
|
||||
|
||||
def close_mosaic(self, hyp):
|
||||
"""Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
|
||||
hyp.mosaic = 0.0 # set mosaic ratio=0.0
|
||||
hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
|
||||
hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
|
||||
@ -206,6 +209,7 @@ class YOLODataset(BaseDataset):
|
||||
|
||||
@staticmethod
|
||||
def collate_fn(batch):
|
||||
"""Collates data samples into batches."""
|
||||
new_batch = {}
|
||||
keys = batch[0].keys()
|
||||
values = list(zip(*[list(b.values()) for b in batch]))
|
||||
@ -234,6 +238,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||
"""
|
||||
|
||||
def __init__(self, root, augment, imgsz, cache=False):
|
||||
"""Initialize YOLO classification dataset with root, image size, augmentations, and cache settings."""
|
||||
super().__init__(root=root)
|
||||
self.torch_transforms = classify_transforms(imgsz)
|
||||
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
|
||||
@ -242,6 +247,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
|
||||
|
||||
def __getitem__(self, i):
|
||||
"""Retrieves the sample at index 'i', reading its image (cached in RAM when enabled)."""
|
||||
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
||||
if self.cache_ram and im is None:
|
||||
im = self.samples[i][3] = cv2.imread(f)
|
||||
@ -265,4 +271,5 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||
class SemanticDataset(BaseDataset):
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize a SemanticDataset object."""
|
||||
pass
|
||||
|
@ -359,6 +359,7 @@ class HUBDatasetStats():
|
||||
return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path
|
||||
|
||||
def _hub_ops(self, f):
|
||||
"""Saves a compressed image for HUB previews."""
|
||||
compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub
|
||||
|
||||
def get_json(self, save=False, verbose=False):
|
||||
|
@ -105,6 +105,7 @@ def try_export(inner_func):
|
||||
inner_args = get_default_args(inner_func)
|
||||
|
||||
def outer_func(*args, **kwargs):
|
||||
"""Export a model."""
|
||||
prefix = inner_args['prefix']
|
||||
try:
|
||||
with Profile() as dt:
|
||||
@ -118,24 +119,6 @@ def try_export(inner_func):
|
||||
return outer_func
|
||||
|
||||
|
||||
class iOSDetectModel(torch.nn.Module):
|
||||
"""Wrap an Ultralytics YOLO model for iOS export."""
|
||||
|
||||
def __init__(self, model, im):
|
||||
super().__init__()
|
||||
b, c, h, w = im.shape # batch, channel, height, width
|
||||
self.model = model
|
||||
self.nc = len(model.names) # number of classes
|
||||
if w == h:
|
||||
self.normalize = 1.0 / w # scalar
|
||||
else:
|
||||
self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)
|
||||
|
||||
def forward(self, x):
|
||||
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
|
||||
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
|
||||
|
||||
|
||||
class Exporter:
|
||||
"""
|
||||
A class for exporting a model.
|
||||
@ -160,6 +143,7 @@ class Exporter:
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, model=None):
|
||||
"""Returns list of exported files/dirs after running callbacks."""
|
||||
self.run_callbacks('on_export_start')
|
||||
t = time.time()
|
||||
format = self.args.format.lower() # to lowercase
|
||||
@ -703,7 +687,7 @@ class Exporter:
|
||||
tmp_file.unlink()
|
||||
|
||||
def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
|
||||
# YOLOv8 CoreML pipeline
|
||||
"""YOLOv8 CoreML pipeline."""
|
||||
import coremltools as ct # noqa
|
||||
|
||||
LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
|
||||
@ -826,11 +810,33 @@ class Exporter:
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Execute all callbacks for a given event."""
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
|
||||
class iOSDetectModel(torch.nn.Module):
|
||||
"""Wrap an Ultralytics YOLO model for iOS export."""
|
||||
|
||||
def __init__(self, model, im):
|
||||
"""Initialize the iOSDetectModel class with a YOLO model and example image."""
|
||||
super().__init__()
|
||||
b, c, h, w = im.shape # batch, channel, height, width
|
||||
self.model = model
|
||||
self.nc = len(model.names) # number of classes
|
||||
if w == h:
|
||||
self.normalize = 1.0 / w # scalar
|
||||
else:
|
||||
self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)
|
||||
|
||||
def forward(self, x):
|
||||
"""Normalize predictions of object detection model with input size-dependent factors."""
|
||||
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
|
||||
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
|
||||
|
||||
|
||||
def export(cfg=DEFAULT_CFG):
|
||||
"""Export a YOLO model to a specific format."""
|
||||
cfg.model = cfg.model or 'yolov8n.yaml'
|
||||
cfg.format = cfg.format or 'torchscript'
|
||||
|
||||
|
@ -107,14 +107,17 @@ class YOLO:
|
||||
self._load(model, task)
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
"""Calls the 'predict' function with given arguments to perform object detection."""
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
"""Raises error if object has no requested attribute."""
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
@staticmethod
|
||||
def is_hub_model(model):
|
||||
"""Check if the provided model is a HUB model."""
|
||||
return any((
|
||||
model.startswith('https://hub.ultra'), # i.e. https://hub.ultralytics.com/models/MODEL_ID
|
||||
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
|
||||
@ -209,6 +212,7 @@ class YOLO:
|
||||
self.model.info(verbose=verbose)
|
||||
|
||||
def fuse(self):
|
||||
"""Fuse PyTorch Conv2d and BatchNorm2d layers."""
|
||||
self._check_is_pytorch_model()
|
||||
self.model.fuse()
|
||||
|
||||
@ -493,9 +497,11 @@ class YOLO:
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
"""Reset arguments when loading a PyTorch model."""
|
||||
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
|
||||
return {k: v for k, v in args.items() if k in include}
|
||||
|
||||
def _reset_callbacks(self):
|
||||
"""Reset all registered callbacks."""
|
||||
for event in callbacks.default_callbacks.keys():
|
||||
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||
|
@ -107,9 +107,11 @@ class BasePredictor:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def preprocess(self, img):
|
||||
"""Prepares input image before inference."""
|
||||
pass
|
||||
|
||||
def write_results(self, idx, results, batch):
|
||||
"""Write inference results to a file or directory."""
|
||||
p, im, _ = batch
|
||||
log_string = ''
|
||||
if len(im.shape) == 3:
|
||||
@ -143,9 +145,11 @@ class BasePredictor:
|
||||
return log_string
|
||||
|
||||
def postprocess(self, preds, img, orig_img):
|
||||
"""Post-processes predictions for an image and returns them."""
|
||||
return preds
|
||||
|
||||
def __call__(self, source=None, model=None, stream=False):
|
||||
"""Performs inference on an image or stream."""
|
||||
self.stream = stream
|
||||
if stream:
|
||||
return self.stream_inference(source, model)
|
||||
@ -159,6 +163,7 @@ class BasePredictor:
|
||||
pass
|
||||
|
||||
def setup_source(self, source):
|
||||
"""Sets up source and inference mode."""
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
||||
if self.args.task == 'classify':
|
||||
transforms = getattr(self.model.model, 'transforms', classify_transforms(self.imgsz[0]))
|
||||
@ -179,6 +184,7 @@ class BasePredictor:
|
||||
|
||||
@smart_inference_mode()
|
||||
def stream_inference(self, source=None, model=None):
|
||||
"""Streams real-time inference on camera feed and saves results to file."""
|
||||
if self.args.verbose:
|
||||
LOGGER.info('')
|
||||
|
||||
@ -264,6 +270,7 @@ class BasePredictor:
|
||||
self.run_callbacks('on_predict_end')
|
||||
|
||||
def setup_model(self, model, verbose=True):
|
||||
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
|
||||
device = select_device(self.args.device, verbose=verbose)
|
||||
model = model or self.args.model
|
||||
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
|
||||
@ -278,6 +285,7 @@ class BasePredictor:
|
||||
self.model.eval()
|
||||
|
||||
def show(self, p):
|
||||
"""Display an image in a window using OpenCV imshow()."""
|
||||
im0 = self.plotted_img
|
||||
if platform.system() == 'Linux' and p not in self.windows:
|
||||
self.windows.append(p)
|
||||
@ -287,6 +295,7 @@ class BasePredictor:
|
||||
cv2.waitKey(500 if self.batch[4].startswith('image') else 1) # 1 millisecond
|
||||
|
||||
def save_preds(self, vid_cap, idx, save_path):
|
||||
"""Save image predictions, or video predictions as mp4, at specified path."""
|
||||
im0 = self.plotted_img
|
||||
# Save imgs
|
||||
if self.dataset.mode == 'image':
|
||||
@ -307,6 +316,7 @@ class BasePredictor:
|
||||
self.vid_writer[idx].write(im0)
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Runs all registered callbacks for a specific event."""
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
|
@ -19,42 +19,41 @@ from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
||||
|
||||
class BaseTensor(SimpleClass):
|
||||
"""
|
||||
|
||||
Attributes:
|
||||
data (torch.Tensor): Base tensor.
|
||||
orig_shape (tuple): Original image size, in the format (height, width).
|
||||
|
||||
Methods:
|
||||
cpu(): Returns a copy of the tensor on CPU memory.
|
||||
numpy(): Returns a copy of the tensor as a numpy array.
|
||||
cuda(): Returns a copy of the tensor on GPU memory.
|
||||
to(): Returns a copy of the tensor with the specified device and dtype.
|
||||
Base tensor class with additional methods for easy manipulation and device handling.
|
||||
"""
|
||||
|
||||
def __init__(self, data, orig_shape) -> None:
|
||||
"""Initialize BaseTensor with data and original shape."""
|
||||
self.data = data
|
||||
self.orig_shape = orig_shape
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""Return the shape of the data tensor."""
|
||||
return self.data.shape
|
||||
|
||||
def cpu(self):
|
||||
"""Return a copy of the tensor on CPU memory."""
|
||||
return self.__class__(self.data.cpu(), self.orig_shape)
|
||||
|
||||
def numpy(self):
|
||||
"""Return a copy of the tensor as a numpy array."""
|
||||
return self.__class__(self.data.numpy(), self.orig_shape)
|
||||
|
||||
def cuda(self):
|
||||
"""Return a copy of the tensor on GPU memory."""
|
||||
return self.__class__(self.data.cuda(), self.orig_shape)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""Return a copy of the tensor with the specified device and dtype."""
|
||||
return self.__class__(self.data.to(*args, **kwargs), self.orig_shape)
|
||||
|
||||
def __len__(self): # override len(results)
|
||||
"""Return the length of the data tensor."""
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Return a BaseTensor with the specified index of the data tensor."""
|
||||
return self.__class__(self.data[idx], self.orig_shape)
|
||||
|
||||
|
||||
@ -83,10 +82,10 @@ class Results(SimpleClass):
|
||||
keypoints (List[List[float]], optional): A list of detected keypoints for each object.
|
||||
speed (dict): A dictionary of preprocess, inference and postprocess speeds in milliseconds per image.
|
||||
_keys (tuple): A tuple of attribute names for non-empty attributes.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None) -> None:
|
||||
"""Initialize the Results class."""
|
||||
self.orig_img = orig_img
|
||||
self.orig_shape = orig_img.shape[:2]
|
||||
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
|
||||
@ -99,16 +98,19 @@ class Results(SimpleClass):
|
||||
self._keys = ('boxes', 'masks', 'probs', 'keypoints')
|
||||
|
||||
def pandas(self):
|
||||
"""Convert the results to a pandas DataFrame (not yet implemented)."""
|
||||
pass
|
||||
# TODO masks.pandas + boxes.pandas + cls.pandas
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Return a Results object for the specified index."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k)[idx])
|
||||
return r
|
||||
|
||||
def update(self, boxes=None, masks=None, probs=None):
|
||||
"""Update the boxes, masks, and probs attributes of the Results object."""
|
||||
if boxes is not None:
|
||||
self.boxes = Boxes(boxes, self.orig_shape)
|
||||
if masks is not None:
|
||||
@ -117,38 +119,45 @@ class Results(SimpleClass):
|
||||
self.probs = probs
|
||||
|
||||
def cpu(self):
|
||||
"""Return a copy of the Results object with all tensors on CPU memory."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).cpu())
|
||||
return r
|
||||
|
||||
def numpy(self):
|
||||
"""Return a copy of the Results object with all tensors as numpy arrays."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).numpy())
|
||||
return r
|
||||
|
||||
def cuda(self):
|
||||
"""Return a copy of the Results object with all tensors on GPU memory."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).cuda())
|
||||
return r
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""Return a copy of the Results object with tensors on the specified device and dtype."""
|
||||
r = self.new()
|
||||
for k in self.keys:
|
||||
setattr(r, k, getattr(self, k).to(*args, **kwargs))
|
||||
return r
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of detections in the Results object."""
|
||||
for k in self.keys:
|
||||
return len(getattr(self, k))
|
||||
|
||||
def new(self):
|
||||
"""Return a new Results object with the same image, path, and names."""
|
||||
return Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
"""Return a list of non-empty attribute names."""
|
||||
return [k for k in self._keys if getattr(self, k) is not None]
|
||||
|
||||
def plot(
|
||||
@ -250,7 +259,8 @@ class Results(SimpleClass):
|
||||
return log_string
|
||||
|
||||
def save_txt(self, txt_file, save_conf=False):
|
||||
"""Save predictions into txt file.
|
||||
"""
|
||||
Save predictions into txt file.
|
||||
|
||||
Args:
|
||||
txt_file (str): txt file path.
|
||||
@ -285,7 +295,8 @@ class Results(SimpleClass):
|
||||
f.write(text + '\n')
|
||||
|
||||
def save_crop(self, save_dir, file_name=Path('im.jpg')):
|
||||
"""Save cropped predictions to `save_dir/cls/file_name.jpg`.
|
||||
"""
|
||||
Save cropped predictions to `save_dir/cls/file_name.jpg`.
|
||||
|
||||
Args:
|
||||
save_dir (str | pathlib.Path): Save path.
|
||||
@ -338,6 +349,7 @@ class Boxes(BaseTensor):
|
||||
"""
|
||||
|
||||
def __init__(self, boxes, orig_shape) -> None:
|
||||
"""Initialize the Boxes class."""
|
||||
if boxes.ndim == 1:
|
||||
boxes = boxes[None, :]
|
||||
n = boxes.shape[-1]
|
||||
@ -349,40 +361,49 @@ class Boxes(BaseTensor):
|
||||
|
||||
@property
|
||||
def xyxy(self):
|
||||
"""Return the boxes in xyxy format."""
|
||||
return self.data[:, :4]
|
||||
|
||||
@property
|
||||
def conf(self):
|
||||
"""Return the confidence values of the boxes."""
|
||||
return self.data[:, -2]
|
||||
|
||||
@property
|
||||
def cls(self):
|
||||
"""Return the class values of the boxes."""
|
||||
return self.data[:, -1]
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
"""Return the track IDs of the boxes (if available)."""
|
||||
return self.data[:, -3] if self.is_track else None
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=2) # maxsize 1 should suffice
|
||||
def xywh(self):
|
||||
"""Return the boxes in xywh format."""
|
||||
return ops.xyxy2xywh(self.xyxy)
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xyxyn(self):
|
||||
"""Return the boxes in xyxy format normalized by original image size."""
|
||||
return self.xyxy / self.orig_shape[[1, 0, 1, 0]]
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=2)
|
||||
def xywhn(self):
|
||||
"""Return the boxes in xywh format normalized by original image size."""
|
||||
return self.xywh / self.orig_shape[[1, 0, 1, 0]]
|
||||
|
||||
def pandas(self):
|
||||
"""Convert the object to a pandas DataFrame (not yet implemented)."""
|
||||
LOGGER.info('results.pandas() method not yet implemented')
|
||||
|
||||
@property
|
||||
def boxes(self):
|
||||
"""Return the raw bboxes tensor (deprecated)."""
|
||||
LOGGER.warning("WARNING ⚠️ 'Boxes.boxes' is deprecated. Use 'Boxes.data' instead.")
|
||||
return self.data
|
||||
|
||||
@ -411,6 +432,7 @@ class Masks(BaseTensor):
|
||||
"""
|
||||
|
||||
def __init__(self, masks, orig_shape) -> None:
|
||||
"""Initialize the Masks class."""
|
||||
if masks.ndim == 2:
|
||||
masks = masks[None, :]
|
||||
super().__init__(masks, orig_shape)
|
||||
@ -418,7 +440,7 @@ class Masks(BaseTensor):
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def segments(self):
|
||||
"""Segments-deprecated (normalized)."""
|
||||
"""Return segments (deprecated; normalized)."""
|
||||
LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and "
|
||||
"'Masks.xy' for segments (pixels) instead.")
|
||||
return self.xyn
|
||||
@ -426,7 +448,7 @@ class Masks(BaseTensor):
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xyn(self):
|
||||
"""Segments (normalized)."""
|
||||
"""Return segments (normalized)."""
|
||||
return [
|
||||
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
|
||||
for x in ops.masks2segments(self.data)]
|
||||
@ -434,12 +456,13 @@ class Masks(BaseTensor):
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def xy(self):
|
||||
"""Segments (pixels)."""
|
||||
"""Return segments (pixels)."""
|
||||
return [
|
||||
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
|
||||
for x in ops.masks2segments(self.data)]
|
||||
|
||||
@property
|
||||
def masks(self):
|
||||
"""Return the raw masks tensor (deprecated)."""
|
||||
LOGGER.warning("WARNING ⚠️ 'Masks.masks' is deprecated. Use 'Masks.data' instead.")
|
||||
return self.data
|
||||
|
@ -159,6 +159,7 @@ class BaseTrainer:
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Run all existing callbacks associated with a particular event."""
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
@ -190,6 +191,7 @@ class BaseTrainer:
|
||||
self._do_train(world_size)
|
||||
|
||||
def _setup_ddp(self, world_size):
|
||||
"""Initializes and sets the DistributedDataParallel parameters for training."""
|
||||
torch.cuda.set_device(RANK)
|
||||
self.device = torch.device('cuda', RANK)
|
||||
LOGGER.info(f'DDP settings: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
||||
@ -259,6 +261,7 @@ class BaseTrainer:
|
||||
self.run_callbacks('on_pretrain_routine_end')
|
||||
|
||||
def _do_train(self, world_size=1):
|
||||
"""Run training, then evaluate and plot if specified by arguments."""
|
||||
if world_size > 1:
|
||||
self._setup_ddp(world_size)
|
||||
|
||||
@ -392,6 +395,7 @@ class BaseTrainer:
|
||||
self.run_callbacks('teardown')
|
||||
|
||||
def save_model(self):
|
||||
"""Save model checkpoints based on various conditions."""
|
||||
ckpt = {
|
||||
'epoch': self.epoch,
|
||||
'best_fitness': self.best_fitness,
|
||||
@ -436,6 +440,7 @@ class BaseTrainer:
|
||||
return ckpt
|
||||
|
||||
def optimizer_step(self):
|
||||
"""Perform a single step of the training optimizer with gradient clipping and EMA update."""
|
||||
self.scaler.unscale_(self.optimizer) # unscale gradients
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients
|
||||
self.scaler.step(self.optimizer)
|
||||
@ -461,9 +466,11 @@ class BaseTrainer:
|
||||
return metrics, fitness
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Raise NotImplementedError; this task trainer doesn't support loading cfg files."""
|
||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||
|
||||
def get_validator(self):
|
||||
"""Raise NotImplementedError; get_validator is not implemented in this trainer."""
|
||||
raise NotImplementedError('get_validator function not implemented in trainer')
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
@ -492,19 +499,24 @@ class BaseTrainer:
|
||||
self.model.names = self.data['names']
|
||||
|
||||
def build_targets(self, preds, targets):
|
||||
"""Builds target tensors for training YOLO model."""
|
||||
pass
|
||||
|
||||
def progress_string(self):
|
||||
"""Returns a string describing training progress."""
|
||||
return ''
|
||||
|
||||
# TODO: may need to put these following functions into callback
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plots training samples during YOLO training."""
|
||||
pass
|
||||
|
||||
def plot_training_labels(self):
|
||||
"""Plots training labels for YOLO model."""
|
||||
pass
|
||||
|
||||
def save_metrics(self, metrics):
|
||||
"""Saves training metrics to a CSV file."""
|
||||
keys, vals = list(metrics.keys()), list(metrics.values())
|
||||
n = len(metrics) + 1 # number of cols
|
||||
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
|
||||
@ -512,9 +524,11 @@ class BaseTrainer:
|
||||
f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plot and display metrics visually."""
|
||||
pass
|
||||
|
||||
def final_eval(self):
|
||||
"""Performs final evaluation and validation for object detection YOLO model."""
|
||||
for f in self.last, self.best:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
@ -525,6 +539,7 @@ class BaseTrainer:
|
||||
self.run_callbacks('on_fit_epoch_end')
|
||||
|
||||
def check_resume(self):
|
||||
"""Check if resume checkpoint exists and update arguments accordingly."""
|
||||
resume = self.args.resume
|
||||
if resume:
|
||||
try:
|
||||
@ -539,6 +554,7 @@ class BaseTrainer:
|
||||
self.resume = resume
|
||||
|
||||
def resume_training(self, ckpt):
|
||||
"""Resume YOLO training from given epoch and best fitness."""
|
||||
if ckpt is None:
|
||||
return
|
||||
best_fitness = 0.0
|
||||
|
@ -195,58 +195,72 @@ class BaseValidator:
|
||||
return stats
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
Appends the given callback.
|
||||
"""
|
||||
"""Appends the given callback."""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Runs all callbacks associated with a specified event."""
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
"""Get data loader from dataset path and batch size."""
|
||||
raise NotImplementedError('get_dataloader function not implemented for this validator')
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocesses an input batch."""
|
||||
return batch
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Post-processes predictions; this base implementation returns them unchanged."""
|
||||
return preds
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initialize performance metrics for the YOLO model."""
|
||||
pass
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
"""Updates metrics based on predictions and batch."""
|
||||
pass
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
"""Finalizes metrics; this base implementation does nothing."""
|
||||
pass
|
||||
|
||||
def get_stats(self):
|
||||
"""Returns statistics about the model's performance."""
|
||||
return {}
|
||||
|
||||
def check_stats(self, stats):
|
||||
"""Checks statistics."""
|
||||
pass
|
||||
|
||||
def print_results(self):
|
||||
"""Prints the results of the model's predictions."""
|
||||
pass
|
||||
|
||||
def get_desc(self):
|
||||
"""Get description of the YOLO model."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def metric_keys(self):
|
||||
"""Returns the metric keys used in YOLO training/validation."""
|
||||
return []
|
||||
|
||||
# TODO: may need to put these following functions into callback
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plots validation samples during training."""
|
||||
pass
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots YOLO model predictions on batch images."""
|
||||
pass
|
||||
|
||||
def pred_to_json(self, preds, batch):
|
||||
"""Convert predictions to JSON format."""
|
||||
pass
|
||||
|
||||
def eval_json(self, stats):
|
||||
"""Evaluate and return JSON format of prediction statistics."""
|
||||
pass
|
||||
|
@ -182,8 +182,10 @@ def plt_settings(rcparams={'font.size': 11}, backend='Agg'):
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
"""Decorator to apply temporary rc parameters and backend to a function."""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Sets rc parameters and backend, calls the original function, and restores the settings."""
|
||||
original_backend = plt.get_backend()
|
||||
plt.switch_backend(backend)
|
||||
|
||||
@ -229,6 +231,7 @@ class EmojiFilter(logging.Filter):
|
||||
"""
|
||||
|
||||
def filter(self, record):
|
||||
"""Filter logs by emoji unicode characters on windows."""
|
||||
record.msg = emojis(record.msg)
|
||||
return super().filter(record)
|
||||
|
||||
@ -573,13 +576,16 @@ class TryExcept(contextlib.ContextDecorator):
|
||||
"""YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
|
||||
|
||||
def __init__(self, msg='', verbose=True):
|
||||
"""Initialize TryExcept class with optional message and verbosity settings."""
|
||||
self.msg = msg
|
||||
self.verbose = verbose
|
||||
|
||||
def __enter__(self):
|
||||
"""Executes when entering TryExcept context; no setup is performed."""
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, value, traceback):
|
||||
"""Defines behavior when exiting a 'with' block, prints error message if necessary."""
|
||||
if self.verbose and value:
|
||||
print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
|
||||
return True
|
||||
@ -589,6 +595,7 @@ def threaded(func):
|
||||
"""Multi-threads a target function and returns thread. Usage: @threaded decorator."""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Multi-threads a given function and returns the thread."""
|
||||
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
||||
thread.start()
|
||||
return thread
|
||||
@ -602,6 +609,7 @@ def set_sentry():
|
||||
"""
|
||||
|
||||
def before_send(event, hint):
|
||||
"""A function executed before sending the event to Sentry."""
|
||||
if 'exc_info' in hint:
|
||||
exc_type, exc_value, tb = hint['exc_info']
|
||||
if exc_type in (KeyboardInterrupt, FileNotFoundError) \
|
||||
@ -698,6 +706,7 @@ def set_settings(kwargs, file=SETTINGS_YAML):
|
||||
|
||||
|
||||
def deprecation_warn(arg, new_arg, version=None):
|
||||
"""Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument."""
|
||||
if not version:
|
||||
version = float(__version__[:3]) + 0.2 # deprecate after 2nd major release
|
||||
LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "
|
||||
|
@ -35,7 +35,30 @@ from ultralytics.yolo.utils.files import file_size
|
||||
from ultralytics.yolo.utils.torch_utils import select_device
|
||||
|
||||
|
||||
def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, half=False, device='cpu', hard_fail=False):
|
||||
def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
|
||||
imgsz=160,
|
||||
half=False,
|
||||
int8=False,
|
||||
device='cpu',
|
||||
hard_fail=False):
|
||||
"""
|
||||
Benchmark a YOLO model across different formats for speed and accuracy.
|
||||
|
||||
Args:
|
||||
model (Union[str, Path], optional): Path to the model file or directory. Default is
|
||||
Path(SETTINGS['weights_dir']) / 'yolov8n.pt'.
|
||||
imgsz (int, optional): Image size for the benchmark. Default is 160.
|
||||
half (bool, optional): Use half-precision for the model if True. Default is False.
|
||||
int8 (bool, optional): Use int8-precision for the model if True. Default is False.
|
||||
device (str, optional): Device to run the benchmark on, either 'cpu' or 'cuda'. Default is 'cpu'.
|
||||
hard_fail (Union[bool, float], optional): If True or a float, assert benchmarks pass with given metric.
|
||||
Default is False.
|
||||
|
||||
Returns:
|
||||
df (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size,
|
||||
metric, and inference time.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
pd.options.display.max_columns = 10
|
||||
pd.options.display.width = 120
|
||||
@ -61,7 +84,7 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, hal
|
||||
filename = model.ckpt_path or model.cfg
|
||||
export = model # PyTorch format
|
||||
else:
|
||||
filename = model.export(imgsz=imgsz, format=format, half=half, device=device) # all others
|
||||
filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device) # all others
|
||||
export = YOLO(filename, task=model.task)
|
||||
assert suffix in str(filename), 'export failed'
|
||||
emoji = '❎' # indicates export succeeded
|
||||
@ -83,7 +106,14 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, hal
|
||||
elif model.task == 'pose':
|
||||
data, key = 'coco8-pose.yaml', 'metrics/mAP50-95(P)'
|
||||
|
||||
results = export.val(data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, verbose=False)
|
||||
results = export.val(data=data,
|
||||
batch=1,
|
||||
imgsz=imgsz,
|
||||
plots=False,
|
||||
device=device,
|
||||
half=half,
|
||||
int8=int8,
|
||||
verbose=False)
|
||||
metric, speed = results.results_dict[key], results.speed['inference']
|
||||
y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
|
||||
except Exception as e:
|
||||
|
@ -2,111 +2,144 @@
|
||||
"""
|
||||
Base callbacks
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
# Trainer callbacks ----------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
"""Called before the pretraining routine starts."""
|
||||
pass
|
||||
|
||||
|
||||
def on_pretrain_routine_end(trainer):
|
||||
"""Called after the pretraining routine ends."""
|
||||
pass
|
||||
|
||||
|
||||
def on_train_start(trainer):
|
||||
"""Called when the training starts."""
|
||||
pass
|
||||
|
||||
|
||||
def on_train_epoch_start(trainer):
|
||||
"""Called at the start of each training epoch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_train_batch_start(trainer):
|
||||
"""Called at the start of each training batch."""
|
||||
pass
|
||||
|
||||
|
||||
def optimizer_step(trainer):
|
||||
"""Called when the optimizer takes a step."""
|
||||
pass
|
||||
|
||||
|
||||
def on_before_zero_grad(trainer):
|
||||
"""Called before the gradients are set to zero."""
|
||||
pass
|
||||
|
||||
|
||||
def on_train_batch_end(trainer):
|
||||
"""Called at the end of each training batch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Called at the end of each training epoch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Called at the end of each fit epoch (train + val)."""
|
||||
pass
|
||||
|
||||
|
||||
def on_model_save(trainer):
|
||||
"""Called when the model is saved."""
|
||||
pass
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Called when the training ends."""
|
||||
pass
|
||||
|
||||
|
||||
def on_params_update(trainer):
|
||||
"""Called when the model parameters are updated."""
|
||||
pass
|
||||
|
||||
|
||||
def teardown(trainer):
|
||||
"""Called during the teardown of the training process."""
|
||||
pass
|
||||
|
||||
|
||||
# Validator callbacks --------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def on_val_start(validator):
|
||||
"""Called when the validation starts."""
|
||||
pass
|
||||
|
||||
|
||||
def on_val_batch_start(validator):
|
||||
"""Called at the start of each validation batch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_val_batch_end(validator):
|
||||
"""Called at the end of each validation batch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_val_end(validator):
|
||||
"""Called when the validation ends."""
|
||||
pass
|
||||
|
||||
|
||||
# Predictor callbacks --------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def on_predict_start(predictor):
|
||||
"""Called when the prediction starts."""
|
||||
pass
|
||||
|
||||
|
||||
def on_predict_batch_start(predictor):
|
||||
"""Called at the start of each prediction batch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_predict_batch_end(predictor):
|
||||
"""Called at the end of each prediction batch."""
|
||||
pass
|
||||
|
||||
|
||||
def on_predict_postprocess_end(predictor):
|
||||
"""Called after the post-processing of the prediction ends."""
|
||||
pass
|
||||
|
||||
|
||||
def on_predict_end(predictor):
|
||||
"""Called when the prediction ends."""
|
||||
pass
|
||||
|
||||
|
||||
# Exporter callbacks ---------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def on_export_start(exporter):
|
||||
"""Called when the model export starts."""
|
||||
pass
|
||||
|
||||
|
||||
def on_export_end(exporter):
|
||||
"""Called when the model export ends."""
|
||||
pass
|
||||
|
||||
|
||||
@ -146,10 +179,23 @@ default_callbacks = {
|
||||
|
||||
|
||||
def get_default_callbacks():
|
||||
"""
|
||||
Return a copy of the default_callbacks dictionary with lists as default values.
|
||||
|
||||
Returns:
|
||||
(defaultdict): A defaultdict with keys from default_callbacks and empty lists as default values.
|
||||
"""
|
||||
return defaultdict(list, deepcopy(default_callbacks))
|
||||
|
||||
|
||||
def add_integration_callbacks(instance):
|
||||
"""
|
||||
Add integration callbacks from various sources to the instance's callbacks.
|
||||
|
||||
Args:
|
||||
instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
|
||||
of callback lists.
|
||||
"""
|
||||
from .clearml import callbacks as clearml_callbacks
|
||||
from .comet import callbacks as comet_callbacks
|
||||
from .hub import callbacks as hub_callbacks
|
||||
|
@ -59,6 +59,7 @@ def _log_plot(title, plot_path) -> None:
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
"""Runs at start of pretraining routine; initializes and connects/logs task to ClearML."""
|
||||
try:
|
||||
task = Task.current_task()
|
||||
if task:
|
||||
@ -83,11 +84,13 @@ def on_pretrain_routine_start(trainer):
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Logs debug samples for the first epoch of YOLO training."""
|
||||
if trainer.epoch == 1 and Task.current_task():
|
||||
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Reports model information to logger at the end of an epoch."""
|
||||
task = Task.current_task()
|
||||
if task:
|
||||
# You should have access to the validation bboxes under jdict
|
||||
@ -105,12 +108,14 @@ def on_fit_epoch_end(trainer):
|
||||
|
||||
|
||||
def on_val_end(validator):
|
||||
"""Logs validation results including labels and predictions."""
|
||||
if Task.current_task():
|
||||
# Log val_labels and val_pred
|
||||
_log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation')
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Logs final model and its name on training completion."""
|
||||
task = Task.current_task()
|
||||
if task:
|
||||
# Log final results, CM matrix + PR plots
|
||||
|
@ -36,6 +36,7 @@ _comet_image_prediction_count = 0
|
||||
|
||||
|
||||
def _get_experiment_type(mode, project_name):
|
||||
"""Return an experiment based on mode and project name."""
|
||||
if mode == 'offline':
|
||||
return comet_ml.OfflineExperiment(project_name=project_name)
|
||||
|
||||
@ -61,6 +62,7 @@ def _create_experiment(args):
|
||||
|
||||
|
||||
def _fetch_trainer_metadata(trainer):
|
||||
"""Returns metadata for YOLO training including epoch and asset saving status."""
|
||||
curr_epoch = trainer.epoch + 1
|
||||
|
||||
train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
|
||||
@ -97,6 +99,7 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin
|
||||
|
||||
|
||||
def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
|
||||
"""Format ground truth annotations for detection."""
|
||||
indices = batch['batch_idx'] == img_idx
|
||||
bboxes = batch['bboxes'][indices]
|
||||
if len(bboxes) == 0:
|
||||
@ -120,6 +123,7 @@ def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, c
|
||||
|
||||
|
||||
def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
|
||||
"""Format YOLO predictions for object detection visualization."""
|
||||
stem = image_path.stem
|
||||
image_id = int(stem) if stem.isnumeric() else stem
|
||||
|
||||
@ -142,6 +146,7 @@ def _format_prediction_annotations_for_detection(image_path, metadata, class_lab
|
||||
|
||||
|
||||
def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
|
||||
"""Join the ground truth and prediction annotations if they exist."""
|
||||
ground_truth_annotations = _format_ground_truth_annotations_for_detection(img_idx, image_path, batch,
|
||||
class_label_map)
|
||||
prediction_annotations = _format_prediction_annotations_for_detection(image_path, prediction_metadata_map,
|
||||
@ -153,6 +158,7 @@ def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, clas
|
||||
|
||||
|
||||
def _create_prediction_metadata_map(model_predictions):
|
||||
"""Create metadata map for model predictions by grouping them based on image ID."""
|
||||
pred_metadata_map = {}
|
||||
for prediction in model_predictions:
|
||||
pred_metadata_map.setdefault(prediction['image_id'], [])
|
||||
@ -162,6 +168,7 @@ def _create_prediction_metadata_map(model_predictions):
|
||||
|
||||
|
||||
def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
|
||||
"""Log the confusion matrix to the Comet experiment."""
|
||||
conf_mat = trainer.validator.confusion_matrix.matrix
|
||||
names = list(trainer.data['names'].values()) + ['background']
|
||||
experiment.log_confusion_matrix(
|
||||
@ -174,6 +181,7 @@ def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
|
||||
|
||||
|
||||
def _log_images(experiment, image_paths, curr_step, annotations=None):
|
||||
"""Logs images to the experiment with optional annotations."""
|
||||
if annotations:
|
||||
for image_path, annotation in zip(image_paths, annotations):
|
||||
experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)
|
||||
@ -184,6 +192,7 @@ def _log_images(experiment, image_paths, curr_step, annotations=None):
|
||||
|
||||
|
||||
def _log_image_predictions(experiment, validator, curr_step):
|
||||
"""Logs predicted boxes for a single image during training."""
|
||||
global _comet_image_prediction_count
|
||||
|
||||
task = validator.args.task
|
||||
@ -225,6 +234,7 @@ def _log_image_predictions(experiment, validator, curr_step):
|
||||
|
||||
|
||||
def _log_plots(experiment, trainer):
|
||||
"""Logs evaluation plots and label plots for the experiment."""
|
||||
plot_filenames = [trainer.save_dir / f'{plots}.png' for plots in EVALUATION_PLOT_NAMES]
|
||||
_log_images(experiment, plot_filenames, None)
|
||||
|
||||
@ -233,6 +243,7 @@ def _log_plots(experiment, trainer):
|
||||
|
||||
|
||||
def _log_model(experiment, trainer):
|
||||
"""Log the best-trained model to Comet.ml."""
|
||||
experiment.log_model(
|
||||
COMET_MODEL_NAME,
|
||||
file_or_folder=str(trainer.best),
|
||||
@ -242,12 +253,14 @@ def _log_model(experiment, trainer):
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
    """Create a Comet experiment from the trainer's args unless a global experiment already exists."""
    if comet_ml.get_global_experiment():
        return  # an experiment is already active, keep using it
    _create_experiment(trainer.args)
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Log metrics and save batch images at the end of training epochs."""
|
||||
experiment = comet_ml.get_global_experiment()
|
||||
if not experiment:
|
||||
return
|
||||
@ -267,6 +280,7 @@ def on_train_epoch_end(trainer):
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Logs model assets at the end of each epoch."""
|
||||
experiment = comet_ml.get_global_experiment()
|
||||
if not experiment:
|
||||
return
|
||||
@ -296,6 +310,7 @@ def on_fit_epoch_end(trainer):
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Perform operations at the end of training."""
|
||||
experiment = comet_ml.get_global_experiment()
|
||||
if not experiment:
|
||||
return
|
||||
|
@ -9,6 +9,7 @@ from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
|
||||
|
||||
|
||||
def on_pretrain_routine_end(trainer):
|
||||
"""Logs info before starting timer for upload rate limit."""
|
||||
session = getattr(trainer, 'hub_session', None)
|
||||
if session:
|
||||
# Start timer for upload rate limit
|
||||
@ -17,6 +18,7 @@ def on_pretrain_routine_end(trainer):
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Uploads training progress metrics at the end of each epoch."""
|
||||
session = getattr(trainer, 'hub_session', None)
|
||||
if session:
|
||||
# Upload metrics after val end
|
||||
@ -35,6 +37,7 @@ def on_fit_epoch_end(trainer):
|
||||
|
||||
|
||||
def on_model_save(trainer):
|
||||
"""Saves checkpoints to Ultralytics HUB with rate limiting."""
|
||||
session = getattr(trainer, 'hub_session', None)
|
||||
if session:
|
||||
# Upload checkpoints with rate limiting
|
||||
@ -46,6 +49,7 @@ def on_model_save(trainer):
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Upload final model and metrics to Ultralytics HUB at the end of training."""
|
||||
session = getattr(trainer, 'hub_session', None)
|
||||
if session:
|
||||
# Upload final model and metrics with exponential standoff
|
||||
@ -57,18 +61,22 @@ def on_train_end(trainer):
|
||||
|
||||
|
||||
def on_train_start(trainer):
    """Call traces() with the trainer's args at a sample rate of 1.0 when training starts."""
    run_args = trainer.args
    traces(run_args, traces_sample_rate=1.0)
|
||||
|
||||
|
||||
def on_val_start(validator):
    """Call traces() with the validator's args at a sample rate of 1.0 when validation starts."""
    run_args = validator.args
    traces(run_args, traces_sample_rate=1.0)
|
||||
|
||||
|
||||
def on_predict_start(predictor):
    """Call traces() with the predictor's args at a sample rate of 1.0 when prediction starts."""
    run_args = predictor.args
    traces(run_args, traces_sample_rate=1.0)
|
||||
|
||||
|
||||
def on_export_start(exporter):
    """Call traces() with the exporter's args at a sample rate of 1.0 when export starts."""
    run_args = exporter.args
    traces(run_args, traces_sample_rate=1.0)
|
||||
|
||||
|
||||
|
@ -16,6 +16,7 @@ except (ImportError, AssertionError):
|
||||
|
||||
|
||||
def on_pretrain_routine_end(trainer):
|
||||
"""Logs training parameters to MLflow."""
|
||||
global mlflow, run, run_id, experiment_name
|
||||
|
||||
if os.environ.get('MLFLOW_TRACKING_URI') is None:
|
||||
@ -45,17 +46,20 @@ def on_pretrain_routine_end(trainer):
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
    """Log the trainer's metrics to the MLflow run as floats, with parentheses removed from metric names."""
    if not mlflow:
        return
    sanitized = {}
    for key, value in trainer.metrics.items():
        # Strip '(' and ')' from the metric name before logging it
        sanitized[re.sub('[()]', '', key)] = float(value)
    run.log_metrics(metrics=sanitized, step=trainer.epoch)
|
||||
|
||||
|
||||
def on_model_save(trainer):
    """Log the trainer's last checkpoint (trainer.last) as an artifact of the MLflow run."""
    if not mlflow:
        return
    last_checkpoint = trainer.last
    run.log_artifact(last_checkpoint)
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Called at end of train loop to log model artifact info."""
|
||||
if mlflow:
|
||||
root_dir = Path(__file__).resolve().parents[3]
|
||||
run.log_artifact(trainer.best)
|
||||
|
@ -7,6 +7,7 @@ except (ImportError, AssertionError):
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Sends training metrics to Ray Tune at end of each epoch."""
|
||||
if ray.tune.is_session_enabled():
|
||||
metrics = trainer.metrics
|
||||
metrics['epoch'] = trainer.epoch
|
||||
|
@ -12,12 +12,14 @@ writer = None # TensorBoard SummaryWriter instance
|
||||
|
||||
|
||||
def _log_scalars(scalars, step=0):
    """Write every (tag, value) pair in 'scalars' to the TensorBoard writer at the given step."""
    if not writer:
        return  # no SummaryWriter available, nothing to do
    for tag, value in scalars.items():
        writer.add_scalar(tag, value, step)
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
"""Initialize TensorBoard logging with SummaryWriter."""
|
||||
if SummaryWriter:
|
||||
try:
|
||||
global writer
|
||||
@ -29,10 +31,12 @@ def on_pretrain_routine_start(trainer):
|
||||
|
||||
|
||||
def on_batch_end(trainer):
    """Log the 'train'-prefixed loss items for the current training loss at step epoch + 1."""
    train_losses = trainer.label_loss_items(trainer.tloss, prefix='train')
    _log_scalars(train_losses, trainer.epoch + 1)
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
    """Log the trainer's metrics as scalars at step epoch + 1."""
    epoch_metrics = trainer.metrics
    step = trainer.epoch + 1
    _log_scalars(epoch_metrics, step)
|
||||
|
||||
|
||||
|
@ -11,11 +11,13 @@ except (ImportError, AssertionError):
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
    """Initialize a wb run from the trainer's args unless a run is already active."""
    if wb.run:
        return  # reuse the run that is already active
    args = trainer.args
    wb.init(project=args.project or 'YOLOv8', name=args.name, config=vars(args))
|
||||
|
||||
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Logs training metrics and model information at the end of an epoch."""
|
||||
wb.run.log(trainer.metrics, step=trainer.epoch + 1)
|
||||
if trainer.epoch == 0:
|
||||
model_info = {
|
||||
@ -26,6 +28,7 @@ def on_fit_epoch_end(trainer):
|
||||
|
||||
|
||||
def on_train_epoch_end(trainer):
|
||||
"""Log metrics and save images at the end of each training epoch."""
|
||||
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
|
||||
wb.run.log(trainer.lr, step=trainer.epoch + 1)
|
||||
if trainer.epoch == 1:
|
||||
@ -35,6 +38,7 @@ def on_train_epoch_end(trainer):
|
||||
|
||||
|
||||
def on_train_end(trainer):
|
||||
"""Save the best model as an artifact at end of training."""
|
||||
art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
|
||||
if trainer.best.exists():
|
||||
art.add_file(trainer.best)
|
||||
|
@ -295,7 +295,7 @@ def check_file(file, suffix='', download=True, hard=True):
|
||||
|
||||
|
||||
def check_yaml(file, suffix=('.yaml', '.yml'), hard=True):
    """Search/download a YAML file via check_file(), checking its suffix, and return the resulting path."""
    yaml_path = check_file(file, suffix, hard=hard)
    return yaml_path
|
||||
|
||||
|
||||
@ -315,6 +315,7 @@ def check_imshow(warn=False):
|
||||
|
||||
|
||||
def check_yolo(verbose=True, device=''):
|
||||
"""Return a human-readable YOLO software and hardware summary."""
|
||||
from ultralytics.yolo.utils.torch_utils import select_device
|
||||
|
||||
if is_colab():
|
||||
|
@ -24,6 +24,7 @@ def find_free_network_port() -> int:
|
||||
|
||||
|
||||
def generate_ddp_file(trainer):
|
||||
"""Generates a DDP file and returns its file name."""
|
||||
module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
|
||||
|
||||
content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__":
|
||||
@ -43,6 +44,7 @@ def generate_ddp_file(trainer):
|
||||
|
||||
|
||||
def generate_ddp_command(world_size, trainer):
|
||||
"""Generates and returns command for distributed training."""
|
||||
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
||||
if not trainer.resume:
|
||||
shutil.rmtree(trainer.save_dir) # remove the save_dir
|
||||
|
@ -192,7 +192,7 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
|
||||
|
||||
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3):
|
||||
# Multithreaded file download and unzip function, used in data.yaml for autodownload
|
||||
"""Downloads and unzips files concurrently if threads > 1, else sequentially."""
|
||||
dir = Path(dir)
|
||||
dir.mkdir(parents=True, exist_ok=True) # make directory
|
||||
if threads > 1:
|
||||
|
@ -6,4 +6,5 @@ from ultralytics.yolo.utils import emojis
|
||||
class HUBModelError(Exception):
    """Exception raised when a requested model cannot be found."""

    def __init__(self, message='Model not found. Please check model URL and try again.'):
        """Initialize the exception with 'message' passed through emojis()."""
        text = emojis(message)
        super().__init__(text)
|
||||
|
@ -11,13 +11,16 @@ class WorkingDirectory(contextlib.ContextDecorator):
|
||||
"""Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager."""
|
||||
|
||||
def __init__(self, new_dir):
|
||||
"""Sets the working directory to 'new_dir' upon instantiation."""
|
||||
self.dir = new_dir # new dir
|
||||
self.cwd = Path.cwd().resolve() # current dir
|
||||
|
||||
def __enter__(self):
|
||||
"""Changes the current directory to the specified directory."""
|
||||
os.chdir(self.dir)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Restore the current working directory on context exit."""
|
||||
os.chdir(self.cwd)
|
||||
|
||||
|
||||
|
@ -14,6 +14,7 @@ def _ntuple(n):
|
||||
"""From PyTorch internals."""
|
||||
|
||||
def parse(x):
    """Return x unchanged if it is iterable, otherwise a tuple containing x repeated n times."""
    if isinstance(x, abc.Iterable):
        return x
    return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
@ -64,6 +65,7 @@ class Bboxes:
|
||||
# return Bboxes(bboxes, format)
|
||||
|
||||
def convert(self, format):
|
||||
"""Converts bounding box format from one type to another."""
|
||||
assert format in _formats
|
||||
if self.format == format:
|
||||
return
|
||||
@ -77,6 +79,7 @@ class Bboxes:
|
||||
self.format = format
|
||||
|
||||
def areas(self):
    """Convert the boxes to 'xyxy' and return (x2 - x1) * (y2 - y1) for every box."""
    self.convert('xyxy')
    boxes = self.bboxes  # read after convert() so the converted coordinates are used
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights
|
||||
|
||||
@ -125,6 +128,7 @@ class Bboxes:
|
||||
self.bboxes[:, 3] += offset[3]
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of boxes."""
|
||||
return len(self.bboxes)
|
||||
|
||||
@classmethod
|
||||
@ -202,9 +206,11 @@ class Instances:
|
||||
self.segments = segments
|
||||
|
||||
def convert_bbox(self, format):
    """Convert the wrapped bounding boxes to the requested 'format' via self._bboxes.convert()."""
    boxes = self._bboxes
    boxes.convert(format=format)
|
||||
|
||||
def bbox_areas(self):
    """Calculate and return the areas of the bounding boxes."""
    # self._bboxes.areas() computes the per-box areas; return them so callers actually receive the result
    # (previously the value was computed and then discarded, so this method always returned None).
    return self._bboxes.areas()
|
||||
|
||||
def scale(self, scale_w, scale_h, bbox_only=False):
|
||||
@ -219,6 +225,7 @@ class Instances:
|
||||
self.keypoints[..., 1] *= scale_h
|
||||
|
||||
def denormalize(self, w, h):
|
||||
"""Denormalizes boxes, segments, and keypoints from normalized coordinates."""
|
||||
if not self.normalized:
|
||||
return
|
||||
self._bboxes.mul(scale=(w, h, w, h))
|
||||
@ -230,6 +237,7 @@ class Instances:
|
||||
self.normalized = False
|
||||
|
||||
def normalize(self, w, h):
|
||||
"""Normalize bounding boxes, segments, and keypoints to image dimensions."""
|
||||
if self.normalized:
|
||||
return
|
||||
self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
|
||||
@ -279,6 +287,7 @@ class Instances:
|
||||
)
|
||||
|
||||
def flipud(self, h):
|
||||
"""Flips the coordinates of bounding boxes, segments, and keypoints vertically."""
|
||||
if self._bboxes.format == 'xyxy':
|
||||
y1 = self.bboxes[:, 1].copy()
|
||||
y2 = self.bboxes[:, 3].copy()
|
||||
@ -291,6 +300,7 @@ class Instances:
|
||||
self.keypoints[..., 1] = h - self.keypoints[..., 1]
|
||||
|
||||
def fliplr(self, w):
|
||||
"""Flips the coordinates of bounding boxes, segments, and keypoints horizontally."""
|
||||
if self._bboxes.format == 'xyxy':
|
||||
x1 = self.bboxes[:, 0].copy()
|
||||
x2 = self.bboxes[:, 2].copy()
|
||||
@ -303,6 +313,7 @@ class Instances:
|
||||
self.keypoints[..., 0] = w - self.keypoints[..., 0]
|
||||
|
||||
def clip(self, w, h):
|
||||
"""Clips bounding boxes, segments, and keypoints values to stay within image boundaries."""
|
||||
ori_format = self._bboxes.format
|
||||
self.convert_bbox(format='xyxy')
|
||||
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
|
||||
@ -316,6 +327,7 @@ class Instances:
|
||||
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
|
||||
|
||||
def update(self, bboxes, segments=None, keypoints=None):
|
||||
"""Updates instance variables."""
|
||||
new_bboxes = Bboxes(bboxes, format=self._bboxes.format)
|
||||
self._bboxes = new_bboxes
|
||||
if segments is not None:
|
||||
@ -324,6 +336,7 @@ class Instances:
|
||||
self.keypoints = keypoints
|
||||
|
||||
def __len__(self):
|
||||
"""Return the length of the instance list."""
|
||||
return len(self.bboxes)
|
||||
|
||||
@classmethod
|
||||
@ -363,4 +376,5 @@ class Instances:
|
||||
|
||||
@property
def bboxes(self):
    """Return the 'bboxes' attribute of the wrapped self._bboxes object."""
    wrapped = self._bboxes
    return wrapped.bboxes
|
||||
|
@ -12,9 +12,11 @@ class VarifocalLoss(nn.Module):
|
||||
"""Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the VarifocalLoss class."""
|
||||
super().__init__()
|
||||
|
||||
def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
|
||||
"""Computes varifocal loss."""
|
||||
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
|
||||
@ -25,6 +27,7 @@ class VarifocalLoss(nn.Module):
|
||||
class BboxLoss(nn.Module):
|
||||
|
||||
def __init__(self, reg_max, use_dfl=False):
|
||||
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
|
||||
super().__init__()
|
||||
self.reg_max = reg_max
|
||||
self.use_dfl = use_dfl
|
||||
@ -64,6 +67,7 @@ class KeypointLoss(nn.Module):
|
||||
self.sigmas = sigmas
|
||||
|
||||
def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
|
||||
"""Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
|
||||
d = (pred_kpts[..., 0] - gt_kpts[..., 0]) ** 2 + (pred_kpts[..., 1] - gt_kpts[..., 1]) ** 2
|
||||
kpt_loss_factor = (torch.sum(kpt_mask != 0) + torch.sum(kpt_mask == 0)) / (torch.sum(kpt_mask != 0) + 1e-9)
|
||||
# e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
|
||||
|
@ -180,6 +180,7 @@ class FocalLoss(nn.Module):
|
||||
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
|
||||
|
||||
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
|
||||
"""Initialize FocalLoss object with given loss function and hyperparameters."""
|
||||
super().__init__()
|
||||
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
|
||||
self.gamma = gamma
|
||||
@ -188,6 +189,7 @@ class FocalLoss(nn.Module):
|
||||
self.loss_fcn.reduction = 'none' # required to apply FL to each element
|
||||
|
||||
def forward(self, pred, true):
|
||||
"""Calculates the focal loss for predictions 'pred' against targets 'true' using the wrapped loss_fcn."""
|
||||
loss = self.loss_fcn(pred, true)
|
||||
# p_t = torch.exp(-loss)
|
||||
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
|
||||
@ -220,6 +222,7 @@ class ConfusionMatrix:
|
||||
"""
|
||||
|
||||
def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'):
|
||||
"""Initialize the confusion matrix attributes for the given number of classes and task."""
|
||||
self.task = task
|
||||
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc))
|
||||
self.nc = nc # number of classes
|
||||
@ -285,9 +288,11 @@ class ConfusionMatrix:
|
||||
self.matrix[dc, self.nc] += 1 # predicted background
|
||||
|
||||
def matrix(self):
    """Return self.matrix (note: __init__ also assigns an instance attribute named 'matrix')."""
    confusion = self.matrix
    return confusion
|
||||
|
||||
def tp_fp(self):
|
||||
"""Returns true positives and false positives."""
|
||||
tp = self.matrix.diagonal() # true positives
|
||||
fp = self.matrix.sum(1) - tp # false positives
|
||||
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
|
||||
@ -679,6 +684,7 @@ class DetMetrics(SimpleClass):
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
|
||||
def process(self, tp, conf, pred_cls, target_cls):
|
||||
"""Process predicted results for object detection and update metrics."""
|
||||
results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir,
|
||||
names=self.names)[2:]
|
||||
self.box.nc = len(self.names)
|
||||
@ -686,28 +692,35 @@ class DetMetrics(SimpleClass):
|
||||
|
||||
@property
def keys(self):
    """Return the metric names for box precision, recall, mAP50 and mAP50-95."""
    box_keys = ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
    return box_keys
|
||||
|
||||
def mean_results(self):
    """Return the mean results reported by the box metric object (self.box.mean_results())."""
    box_means = self.box.mean_results()
    return box_means
|
||||
|
||||
def class_result(self, i):
    """Return the box metric object's result for class index i (self.box.class_result(i))."""
    per_class = self.box.class_result(i)
    return per_class
|
||||
|
||||
@property
def maps(self):
    """Return the 'maps' attribute of the box metric object."""
    box_metrics = self.box
    return box_metrics.maps
|
||||
|
||||
@property
def fitness(self):
    """Return the fitness value computed by the box metric object (self.box.fitness())."""
    box_fitness = self.box.fitness()
    return box_fitness
|
||||
|
||||
@property
def ap_class_index(self):
    """Return the 'ap_class_index' attribute of the box metric object."""
    box_metrics = self.box
    return box_metrics.ap_class_index
|
||||
|
||||
@property
def results_dict(self):
    """Return a dict mapping each metric key, plus 'fitness', to its mean result and the fitness value."""
    names = self.keys + ['fitness']
    values = self.mean_results() + [self.fitness]
    return dict(zip(names, values))
|
||||
|
||||
|
||||
@ -781,22 +794,27 @@ class SegmentMetrics(SimpleClass):
|
||||
|
||||
@property
def keys(self):
    """Return the metric names for box (B) followed by mask (M) precision, recall, mAP50 and mAP50-95."""
    box_keys = ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
    mask_keys = ['metrics/precision(M)', 'metrics/recall(M)', 'metrics/mAP50(M)', 'metrics/mAP50-95(M)']
    return box_keys + mask_keys
|
||||
|
||||
def mean_results(self):
    """Return the box mean results followed by the seg mean results, combined with '+'."""
    box_means = self.box.mean_results()
    seg_means = self.seg.mean_results()
    return box_means + seg_means
|
||||
|
||||
def class_result(self, i):
    """Return the box result for class index i followed by the seg result for the same class, combined with '+'."""
    box_result = self.box.class_result(i)
    seg_result = self.seg.class_result(i)
    return box_result + seg_result
|
||||
|
||||
@property
def maps(self):
    """Return self.box.maps + self.seg.maps."""
    box_maps = self.box.maps
    seg_maps = self.seg.maps
    return box_maps + seg_maps
|
||||
|
||||
@property
def fitness(self):
    """Return the sum of the seg fitness and the box fitness."""
    seg_fitness = self.seg.fitness()
    box_fitness = self.box.fitness()
    return seg_fitness + box_fitness
|
||||
|
||||
@property
|
||||
@ -806,6 +824,7 @@ class SegmentMetrics(SimpleClass):
|
||||
|
||||
@property
def results_dict(self):
    """Return a dict mapping each metric key, plus 'fitness', to its mean result and the fitness value."""
    names = self.keys + ['fitness']
    values = self.mean_results() + [self.fitness]
    return dict(zip(names, values))
|
||||
|
||||
|
||||
@ -846,6 +865,7 @@ class PoseMetrics(SegmentMetrics):
|
||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||
|
||||
def __getattr__(self, attr):
|
||||
"""Raises an AttributeError if an invalid attribute is accessed."""
|
||||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
@ -884,22 +904,27 @@ class PoseMetrics(SegmentMetrics):
|
||||
|
||||
@property
def keys(self):
    """Return the metric names for box (B) followed by pose (P) precision, recall, mAP50 and mAP50-95."""
    box_keys = ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
    pose_keys = ['metrics/precision(P)', 'metrics/recall(P)', 'metrics/mAP50(P)', 'metrics/mAP50-95(P)']
    return box_keys + pose_keys
|
||||
|
||||
def mean_results(self):
    """Return the box mean results followed by the pose mean results, combined with '+'."""
    box_means = self.box.mean_results()
    pose_means = self.pose.mean_results()
    return box_means + pose_means
|
||||
|
||||
def class_result(self, i):
    """Return the box result for class index i followed by the pose result for the same class, combined with '+'."""
    box_result = self.box.class_result(i)
    pose_result = self.pose.class_result(i)
    return box_result + pose_result
|
||||
|
||||
@property
def maps(self):
    """Return self.box.maps + self.pose.maps."""
    box_maps = self.box.maps
    pose_maps = self.pose.maps
    return box_maps + pose_maps
|
||||
|
||||
@property
def fitness(self):
    """Return the sum of the pose fitness and the box fitness."""
    pose_fitness = self.pose.fitness()
    box_fitness = self.box.fitness()
    return pose_fitness + box_fitness
|
||||
|
||||
|
||||
@ -935,12 +960,15 @@ class ClassifyMetrics(SimpleClass):
|
||||
|
||||
@property
def fitness(self):
    """Return self.top5, which serves as the fitness score."""
    score = self.top5
    return score
|
||||
|
||||
@property
def results_dict(self):
    """Return a dict mapping the metric keys plus 'fitness' to top1, top5 and the fitness value."""
    names = self.keys + ['fitness']
    values = [self.top1, self.top5, self.fitness]
    return dict(zip(names, values))
|
||||
|
||||
@property
def keys(self):
    """Return the metric names for top-1 and top-5 accuracy."""
    accuracy_keys = ['metrics/accuracy_top1', 'metrics/accuracy_top5']
    return accuracy_keys
|
||||
|
@ -33,6 +33,7 @@ class Colors:
|
||||
dtype=np.uint8)
|
||||
|
||||
def __call__(self, i, bgr=False):
|
||||
"""Converts hex color codes to rgb values."""
|
||||
c = self.palette[int(i) % self.n]
|
||||
return (c[2], c[1], c[0]) if bgr else c
|
||||
|
||||
@ -47,6 +48,7 @@ colors = Colors() # create instance for 'from utils.plots import colors'
|
||||
class Annotator:
|
||||
# YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
|
||||
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
||||
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
|
||||
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
|
||||
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
|
||||
self.pil = pil or non_ascii
|
||||
@ -71,7 +73,7 @@ class Annotator:
|
||||
self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
|
||||
|
||||
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
|
||||
# Add one xyxy box to image with label
|
||||
"""Add one xyxy box to image with label."""
|
||||
if isinstance(box, torch.Tensor):
|
||||
box = box.tolist()
|
||||
if self.pil or not is_ascii(label):
|
||||
@ -191,7 +193,7 @@ class Annotator:
|
||||
self.draw.rectangle(xy, fill, outline, width)
|
||||
|
||||
def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
|
||||
# Add text to image (PIL-only)
|
||||
"""Adds text to an image using PIL or cv2."""
|
||||
if anchor == 'bottom': # start y from font bottom
|
||||
w, h = self.font.getsize(text) # text width, height
|
||||
xy[1] += 1 - h
|
||||
@ -214,6 +216,7 @@ class Annotator:
|
||||
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
||||
@plt_settings()
|
||||
def plot_labels(boxes, cls, names=(), save_dir=Path('')):
|
||||
"""Save and plot image with no axis or spines."""
|
||||
import pandas as pd
|
||||
import seaborn as sn
|
||||
|
||||
@ -260,7 +263,7 @@ def plot_labels(boxes, cls, names=(), save_dir=Path('')):
|
||||
|
||||
|
||||
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
|
||||
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
|
||||
"""Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop."""
|
||||
b = xyxy2xywh(xyxy.view(-1, 4)) # boxes
|
||||
if square:
|
||||
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
|
||||
|
@ -69,6 +69,7 @@ class TaskAlignedAssigner(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
|
||||
"""Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
|
||||
super().__init__()
|
||||
self.topk = topk
|
||||
self.num_classes = num_classes
|
||||
@ -137,6 +138,7 @@ class TaskAlignedAssigner(nn.Module):
|
||||
return mask_pos, align_metric, overlaps
|
||||
|
||||
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
|
||||
"""Compute alignment metric given predicted and ground truth bounding boxes."""
|
||||
na = pd_bboxes.shape[-2]
|
||||
mask_gt = mask_gt.bool() # b, max_num_obj, h*w
|
||||
overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
|
||||
|
@ -43,6 +43,7 @@ def smart_inference_mode():
|
||||
"""Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
|
||||
|
||||
def decorate(fn):
    """Wrap fn with torch.inference_mode() when TORCH_1_9 is truthy, otherwise with torch.no_grad()."""
    # The conditional expression avoids touching torch.inference_mode when TORCH_1_9 is falsy
    context_factory = torch.inference_mode if TORCH_1_9 else torch.no_grad
    return context_factory()(fn)
|
||||
|
||||
return decorate
|
||||
@ -232,7 +233,7 @@ def make_divisible(x, divisor):
|
||||
|
||||
|
||||
def copy_attr(a, b, include=(), exclude=()):
|
||||
# Copy attributes from 'b' to 'a', options to only include [...] and to exclude [...]
|
||||
"""Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
|
||||
for k, v in b.__dict__.items():
|
||||
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
|
||||
continue
|
||||
@ -246,7 +247,7 @@ def get_latest_opset():
|
||||
|
||||
|
||||
def intersect_dicts(da, db, exclude=()):
    """Return the items of da whose key is in db, contains no 'exclude' substring, and has a matching shape."""
    matched = {}
    for key, value in da.items():
        if key not in db:
            continue
        if any(pattern in key for pattern in exclude):
            continue  # key contains an excluded substring
        if value.shape == db[key].shape:
            matched[key] = value  # keep the value from da
    return matched
|
||||
|
||||
|
||||
@ -310,7 +311,7 @@ class ModelEMA:
|
||||
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
|
||||
|
||||
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
    """Copy attributes from 'model' onto self.ema via copy_attr() when EMA is enabled."""
    if not self.enabled:
        return
    copy_attr(self.ema, model, include, exclude)
|
||||
|
||||
|
@ -10,10 +10,12 @@ from ultralytics.yolo.utils import DEFAULT_CFG, ROOT
|
||||
class ClassificationPredictor(BasePredictor):
|
||||
|
||||
def preprocess(self, img):
    """Move img (tensor or numpy array) to the model's device and cast it to fp16 or fp32 per self.model.fp16."""
    if not isinstance(img, torch.Tensor):
        img = torch.from_numpy(img)
    img = img.to(self.model.device)
    if self.model.fp16:
        return img.half()
    return img.float()
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Postprocesses predictions to return Results objects."""
|
||||
results = []
|
||||
for i, pred in enumerate(preds):
|
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||
@ -25,6 +27,7 @@ class ClassificationPredictor(BasePredictor):
|
||||
|
||||
|
||||
def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Run YOLO model predictions on input images/videos."""
|
||||
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
|
||||
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||
else 'https://ultralytics.com/images/bus.jpg'
|
||||
|
@ -14,15 +14,18 @@ from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """Initialize the trainer, setting overrides['task'] to 'classify' before calling the parent initializer."""
    overrides = {} if overrides is None else overrides
    overrides['task'] = 'classify'  # set on the caller's dict when one is passed in
    super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def set_model_attributes(self):
    """Copy the 'names' entry of self.data onto the model as model.names."""
    class_names = self.data['names']
    self.model.names = class_names
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Returns a modified PyTorch model configured for training YOLO."""
|
||||
model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
@ -69,6 +72,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
return # dont return ckpt. Classification doesn't support resume
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
|
||||
loader = build_classification_dataloader(path=dataset_path,
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch_size if mode == 'train' else (batch_size * 2),
|
||||
@ -84,19 +88,23 @@ class ClassificationTrainer(BaseTrainer):
|
||||
return loader
|
||||
|
||||
def preprocess_batch(self, batch):
    """Move the batch's 'img' and 'cls' entries to the trainer's device and return the same batch dict."""
    for key in ('img', 'cls'):
        batch[key] = batch[key].to(self.device)
    return batch
|
||||
|
||||
def progress_string(self):
    """Return a newline followed by the progress header: Epoch, GPU_mem, each loss name, Instances, Size."""
    columns = ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
    header_format = '\n' + '%11s' * len(columns)  # every column is right-aligned in 11 characters
    return header_format % columns
|
||||
|
||||
def get_validator(self):
    """Set self.loss_names to ['loss'] and return a ClassificationValidator for the test loader and save dir."""
    self.loss_names = ['loss']
    validator = v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
    return validator
|
||||
|
||||
def criterion(self, preds, batch):
    """
    Return the summed cross-entropy of preds against batch['cls'], divided by self.args.nbs.

    The result is a (loss, loss_items) tuple where loss_items is the same value detached from autograd.
    """
    summed = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum')
    loss = summed / self.args.nbs
    return loss, loss.detach()
|
||||
@ -113,9 +121,11 @@ class ClassificationTrainer(BaseTrainer):
|
||||
return dict(zip(keys, loss_items))
|
||||
|
||||
def resume_training(self, ckpt):
    """Accept a checkpoint and do nothing with it; always returns None."""
    return None
|
||||
|
||||
def final_eval(self):
|
||||
"""Evaluate trained model and save validation results."""
|
||||
for f in self.last, self.best:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
@ -130,6 +140,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Train the YOLO classification model."""
|
||||
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
|
||||
data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
|
||||
device = cfg.device if cfg.device is not None else ''
|
||||
|
@ -9,14 +9,17 @@ from ultralytics.yolo.utils.metrics import ClassifyMetrics, ConfusionMatrix
|
||||
class ClassificationValidator(BaseValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
    """Run the parent validator setup, then set args.task to 'classify' and attach a ClassifyMetrics object."""
    super().__init__(dataloader, save_dir, pbar, args, _callbacks)
    # Task tag and metrics container are fixed for this validator.
    self.args.task = 'classify'
    self.metrics = ClassifyMetrics()
|
||||
|
||||
def get_desc(self):
    """Return the results header: 'classes' right-aligned in 22 chars, then 'top1_acc' and 'top5_acc' in 11 each."""
    label, *metric_names = 'classes', 'top1_acc', 'top5_acc'
    return label.rjust(22) + ''.join(name.rjust(11) for name in metric_names)
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
|
||||
self.names = model.names
|
||||
self.nc = len(model.names)
|
||||
self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify')
|
||||
@ -24,17 +27,20 @@ class ClassificationValidator(BaseValidator):
|
||||
self.targets = []
|
||||
|
||||
def preprocess(self, batch):
    """Move 'img' and 'cls' to the validator's device, casting 'img' to half if args.half else float."""
    images = batch['img'].to(self.device, non_blocking=True)
    batch['img'] = images.half() if self.args.half else images.float()
    batch['cls'] = batch['cls'].to(self.device)
    return batch
|
||||
|
||||
def update_metrics(self, preds, batch):
    """Append the top-k class indices per row (k = min(number of class names, 5)) and the batch targets."""
    top_k = min(len(self.model.names), 5)
    ranked = preds.argsort(1, descending=True)  # class indices ordered by descending score
    self.pred.append(ranked[:, :top_k])
    self.targets.append(batch['cls'])
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
"""Finalizes metrics of the model such as confusion_matrix and speed."""
|
||||
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
|
||||
@ -42,10 +48,12 @@ class ClassificationValidator(BaseValidator):
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
def get_stats(self):
    """Pass the accumulated targets and predictions to metrics.process and return metrics.results_dict."""
    metrics = self.metrics
    metrics.process(self.targets, self.pred)
    return metrics.results_dict
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
"""Builds and returns a data loader for classification tasks with given parameters."""
|
||||
return build_classification_dataloader(path=dataset_path,
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch_size,
|
||||
@ -54,11 +62,13 @@ class ClassificationValidator(BaseValidator):
|
||||
workers=self.args.workers)
|
||||
|
||||
def print_results(self):
    """Log one 'all' row containing the metrics object's top-1 and top-5 values."""
    row_format = '%22s' + '%11.3g' * len(self.metrics.keys)  # one numeric column per metric key
    row = ('all', self.metrics.top1, self.metrics.top5)
    LOGGER.info(row_format % row)
|
||||
|
||||
|
||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Validate YOLO model using custom data."""
|
||||
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
|
||||
data = cfg.data or 'mnist160'
|
||||
|
||||
|
@ -10,12 +10,14 @@ from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
|
||||
class DetectionPredictor(BasePredictor):
|
||||
|
||||
def preprocess(self, img):
    """
    Convert the input to a float tensor on the model's device and scale it from 0-255 to 0.0-1.0.

    Args:
        img (torch.Tensor | np.ndarray): Input image data; non-tensor input is wrapped with torch.from_numpy.

    Returns:
        (torch.Tensor): Tensor on self.model.device, fp16 if self.model.fp16 else fp32, divided by 255.
    """
    img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
    img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
    # Divide out-of-place: from_numpy shares memory with the array, and .to()/.float() hand back the
    # very same tensor when device and dtype already match, so `img /= 255` would rescale the caller's data.
    return img / 255  # 0 - 255 to 0.0 - 1.0
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Postprocesses predictions and returns a list of Results objects."""
|
||||
preds = ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
@ -35,6 +37,7 @@ class DetectionPredictor(BasePredictor):
|
||||
|
||||
|
||||
def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Runs YOLO model inference on input image(s)."""
|
||||
model = cfg.model or 'yolov8n.pt'
|
||||
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||
else 'https://ultralytics.com/images/bus.jpg'
|
||||
|
@ -44,6 +44,7 @@ class DetectionTrainer(BaseTrainer):
|
||||
rect=mode == 'val', data_info=self.data)[0]
|
||||
|
||||
def preprocess_batch(self, batch):
    """Move batch['img'] to the trainer's device, convert it to float divided by 255, and return the batch."""
    images = batch['img'].to(self.device, non_blocking=True)
    batch['img'] = images.float() / 255
    return batch
|
||||
|
||||
@ -58,16 +59,19 @@ class DetectionTrainer(BaseTrainer):
|
||||
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
    """Build a DetectionModel with the dataset's class count (verbose only when RANK == -1); load weights if given."""
    detector = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
    if not weights:
        return detector
    detector.load(weights)
    return detector
|
||||
|
||||
def get_validator(self):
    """Set the detection loss names and return a DetectionValidator built from the test loader and copied args."""
    self.loss_names = ('box_loss', 'cls_loss', 'dfl_loss')
    validator = v8.detect.DetectionValidator(
        self.test_loader,
        save_dir=self.save_dir,
        args=copy(self.args),  # the validator receives its own copy of the args
    )
    return validator
|
||||
|
||||
def criterion(self, preds, batch):
    """Apply the cached loss callable to preds and batch, creating Loss(de_parallel(self.model)) on first use."""
    try:
        loss_fn = self.compute_loss
    except AttributeError:
        # First call: build the loss object once and keep it on the trainer.
        loss_fn = self.compute_loss = Loss(de_parallel(self.model))
    return loss_fn(preds, batch)
|
||||
@ -85,10 +89,12 @@ class DetectionTrainer(BaseTrainer):
|
||||
return keys
|
||||
|
||||
def progress_string(self):
    """Return a newline plus the progress header: Epoch, GPU_mem, each loss name, Instances and Size."""
    headers = ['Epoch', 'GPU_mem']
    headers.extend(self.loss_names)
    headers += ['Instances', 'Size']
    template = '\n' + '%11s' * len(headers)  # 4 fixed columns plus one per loss name
    return template % tuple(headers)
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plots training samples with their annotations."""
|
||||
plot_images(images=batch['img'],
|
||||
batch_idx=batch['batch_idx'],
|
||||
cls=batch['cls'].squeeze(-1),
|
||||
@ -97,9 +103,11 @@ class DetectionTrainer(BaseTrainer):
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg')
|
||||
|
||||
def plot_metrics(self):
    """Call plot_results on the trainer's CSV file."""
    csv_path = self.csv
    plot_results(file=csv_path)  # save results.png
|
||||
|
||||
def plot_training_labels(self):
|
||||
"""Create a labeled training plot of the YOLO model."""
|
||||
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
|
||||
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
|
||||
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir)
|
||||
@ -129,6 +137,7 @@ class Loss:
|
||||
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
|
||||
|
||||
def preprocess(self, targets, batch_size, scale_tensor):
|
||||
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
|
||||
if targets.shape[0] == 0:
|
||||
out = torch.zeros(batch_size, 0, 5, device=self.device)
|
||||
else:
|
||||
@ -145,6 +154,7 @@ class Loss:
|
||||
return out
|
||||
|
||||
def bbox_decode(self, anchor_points, pred_dist):
|
||||
"""Decode predicted object bounding box coordinates from anchor points and distribution."""
|
||||
if self.use_dfl:
|
||||
b, a, c = pred_dist.shape # batch, anchors, channels
|
||||
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
|
||||
@ -153,6 +163,7 @@ class Loss:
|
||||
return dist2bbox(pred_dist, anchor_points, xywh=False)
|
||||
|
||||
def __call__(self, preds, batch):
|
||||
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
|
||||
loss = torch.zeros(3, device=self.device) # box, cls, dfl
|
||||
feats = preds[1] if isinstance(preds, tuple) else preds
|
||||
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
||||
@ -199,6 +210,7 @@ class Loss:
|
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Train and optimize YOLO model given training data and device."""
|
||||
model = cfg.model or 'yolov8n.pt'
|
||||
data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist")
|
||||
device = cfg.device if cfg.device is not None else ''
|
||||
|
@ -19,6 +19,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
class DetectionValidator(BaseValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
"""Initialize detection model with necessary variables and settings."""
|
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.args.task = 'detect'
|
||||
self.is_coco = False
|
||||
@ -28,6 +29,7 @@ class DetectionValidator(BaseValidator):
|
||||
self.niou = self.iouv.numel()
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocesses batch of images for YOLO training."""
|
||||
batch['img'] = batch['img'].to(self.device, non_blocking=True)
|
||||
batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255
|
||||
for k in ['batch_idx', 'cls', 'bboxes']:
|
||||
@ -40,6 +42,7 @@ class DetectionValidator(BaseValidator):
|
||||
return batch
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initialize evaluation metrics for YOLO."""
|
||||
val = self.data.get(self.args.split, '') # validation path
|
||||
self.is_coco = isinstance(val, str) and 'coco' in val and val.endswith(f'{os.sep}val2017.txt') # is COCO
|
||||
self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
|
||||
@ -54,9 +57,11 @@ class DetectionValidator(BaseValidator):
|
||||
self.stats = []
|
||||
|
||||
def get_desc(self):
    """Return the results header: 'Class' right-aligned in 22 chars, then six 11-char columns."""
    columns = ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)')
    return columns[0].rjust(22) + ''.join(col.rjust(11) for col in columns[1:])
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Apply Non-maximum suppression to prediction outputs."""
|
||||
preds = ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
@ -113,10 +118,12 @@ class DetectionValidator(BaseValidator):
|
||||
self.save_one_txt(predn, self.args.save_conf, shape, file)
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
    """Copy the validator's speed and confusion matrix onto its metrics object; extra arguments are unused."""
    metrics = self.metrics
    metrics.speed = self.speed
    metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
def get_stats(self):
|
||||
"""Returns metrics statistics and results dictionary."""
|
||||
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
|
||||
if len(stats) and stats[0].any():
|
||||
self.metrics.process(*stats)
|
||||
@ -124,6 +131,7 @@ class DetectionValidator(BaseValidator):
|
||||
return self.metrics.results_dict
|
||||
|
||||
def print_results(self):
|
||||
"""Prints training/validation set metrics per class."""
|
||||
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
|
||||
LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
|
||||
if self.nt_per_class.sum() == 0:
|
||||
@ -183,6 +191,7 @@ class DetectionValidator(BaseValidator):
|
||||
mode='val')[0]
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plot validation image samples."""
|
||||
plot_images(batch['img'],
|
||||
batch['batch_idx'],
|
||||
batch['cls'].squeeze(-1),
|
||||
@ -192,6 +201,7 @@ class DetectionValidator(BaseValidator):
|
||||
names=self.names)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots predicted bounding boxes on input images and saves the result."""
|
||||
plot_images(batch['img'],
|
||||
*output_to_target(preds, max_det=15),
|
||||
paths=batch['im_file'],
|
||||
@ -199,6 +209,7 @@ class DetectionValidator(BaseValidator):
|
||||
names=self.names) # pred
|
||||
|
||||
def save_one_txt(self, predn, save_conf, shape, file):
|
||||
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
|
||||
gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
|
||||
for *xyxy, conf, cls in predn.tolist():
|
||||
xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
|
||||
@ -207,6 +218,7 @@ class DetectionValidator(BaseValidator):
|
||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||
|
||||
def pred_to_json(self, predn, filename):
|
||||
"""Serialize YOLO predictions to COCO json format."""
|
||||
stem = Path(filename).stem
|
||||
image_id = int(stem) if stem.isnumeric() else stem
|
||||
box = ops.xyxy2xywh(predn[:, :4]) # xywh
|
||||
@ -219,6 +231,7 @@ class DetectionValidator(BaseValidator):
|
||||
'score': round(p[4], 5)})
|
||||
|
||||
def eval_json(self, stats):
|
||||
"""Evaluates YOLO output in JSON format and returns performance statistics."""
|
||||
if self.args.save_json and self.is_coco and len(self.jdict):
|
||||
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
|
||||
pred_json = self.save_dir / 'predictions.json' # predictions
|
||||
@ -245,6 +258,7 @@ class DetectionValidator(BaseValidator):
|
||||
|
||||
|
||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Validate trained YOLO model on validation dataset."""
|
||||
model = cfg.model or 'yolov8n.pt'
|
||||
data = cfg.data or 'coco128.yaml'
|
||||
|
||||
|
@ -8,6 +8,7 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
||||
class PosePredictor(DetectionPredictor):
|
||||
|
||||
def postprocess(self, preds, img, orig_img):
|
||||
"""Return detection results for a given input image or list of images."""
|
||||
preds = ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
@ -35,6 +36,7 @@ class PosePredictor(DetectionPredictor):
|
||||
|
||||
|
||||
def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Runs YOLO to predict objects in an image or video."""
|
||||
model = cfg.model or 'yolov8n-pose.pt'
|
||||
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||
else 'https://ultralytics.com/images/bus.jpg'
|
||||
|
@ -21,12 +21,14 @@ from ultralytics.yolo.v8.detect.train import Loss
|
||||
class PoseTrainer(v8.detect.DetectionTrainer):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """Set the 'task' override to 'pose', then hand cfg, overrides and callbacks to the parent initializer."""
    overrides = {} if overrides is None else overrides
    overrides['task'] = 'pose'  # written into the caller's mapping when one was supplied
    super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Get pose estimation model with specified configuration and weights."""
|
||||
model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
@ -34,19 +36,23 @@ class PoseTrainer(v8.detect.DetectionTrainer):
|
||||
return model
|
||||
|
||||
def set_model_attributes(self):
    """Run the parent's set_model_attributes, then store the dataset's 'kpt_shape' on the model."""
    super().set_model_attributes()
    kpt_shape = self.data['kpt_shape']
    self.model.kpt_shape = kpt_shape
|
||||
|
||||
def get_validator(self):
    """Set the pose loss names and return a PoseValidator built from the test loader and copied args."""
    self.loss_names = ('box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss')
    validator = v8.pose.PoseValidator(
        self.test_loader,
        save_dir=self.save_dir,
        args=copy(self.args),  # the validator receives its own copy of the args
    )
    return validator
|
||||
|
||||
def criterion(self, preds, batch):
    """Apply the cached loss callable to preds and batch, creating PoseLoss(de_parallel(self.model)) on first use."""
    try:
        loss_fn = self.compute_loss
    except AttributeError:
        # First call: build the loss object once and keep it on the trainer.
        loss_fn = self.compute_loss = PoseLoss(de_parallel(self.model))
    return loss_fn(preds, batch)
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
|
||||
images = batch['img']
|
||||
kpts = batch['keypoints']
|
||||
cls = batch['cls'].squeeze(-1)
|
||||
@ -62,6 +68,7 @@ class PoseTrainer(v8.detect.DetectionTrainer):
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg')
|
||||
|
||||
def plot_metrics(self):
    """Call plot_results on the trainer's CSV file with pose=True."""
    csv_path = self.csv
    plot_results(file=csv_path, pose=True)  # save results.png
|
||||
|
||||
|
||||
@ -78,6 +85,7 @@ class PoseLoss(Loss):
|
||||
self.keypoint_loss = KeypointLoss(sigmas=sigmas)
|
||||
|
||||
def __call__(self, preds, batch):
|
||||
"""Calculate the total loss and detach it."""
|
||||
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
|
||||
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
|
||||
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
||||
@ -145,6 +153,7 @@ class PoseLoss(Loss):
|
||||
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
||||
|
||||
def kpts_decode(self, anchor_points, pred_kpts):
|
||||
"""Decodes predicted keypoints to image coordinates."""
|
||||
y = pred_kpts.clone()
|
||||
y[..., :2] *= 2.0
|
||||
y[..., 0] += anchor_points[:, [0]] - 0.5
|
||||
@ -153,6 +162,7 @@ class PoseLoss(Loss):
|
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Train the YOLO model on the given data and device."""
|
||||
model = cfg.model or 'yolov8n-pose.yaml'
|
||||
data = cfg.data or 'coco8-pose.yaml'
|
||||
device = cfg.device if cfg.device is not None else ''
|
||||
|
@ -15,20 +15,24 @@ from ultralytics.yolo.v8.detect import DetectionValidator
|
||||
class PoseValidator(DetectionValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
    """Run the parent validator setup, then set args.task to 'pose' and attach PoseMetrics for the save dir."""
    super().__init__(dataloader, save_dir, pbar, args, _callbacks)
    # Task tag and metrics container are fixed for this validator.
    self.args.task = 'pose'
    self.metrics = PoseMetrics(save_dir=self.save_dir)
|
||||
|
||||
def preprocess(self, batch):
    """Apply the parent's preprocessing, then move batch['keypoints'] to the device and cast it to float."""
    batch = super().preprocess(batch)
    keypoints = batch['keypoints'].to(self.device)
    batch['keypoints'] = keypoints.float()
    return batch
|
||||
|
||||
def get_desc(self):
    """Return the results header: 'Class' in 22 chars, then Images/Instances and the Box and Pose metric columns."""
    box_columns = ('Box(P', 'R', 'mAP50', 'mAP50-95)')
    pose_columns = ('Pose(P', 'R', 'mAP50', 'mAP50-95)')
    narrow = ('Images', 'Instances', *box_columns, *pose_columns)  # ten 11-char columns
    return 'Class'.rjust(22) + ''.join(col.rjust(11) for col in narrow)
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Apply non-maximum suppression and return detections with high confidence scores."""
|
||||
preds = ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
@ -40,6 +44,7 @@ class PoseValidator(DetectionValidator):
|
||||
return preds
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initiate pose estimation metrics for YOLO model."""
|
||||
super().init_metrics(model)
|
||||
self.kpt_shape = self.data['kpt_shape']
|
||||
is_pose = self.kpt_shape == [17, 3]
|
||||
@ -137,6 +142,7 @@ class PoseValidator(DetectionValidator):
|
||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
|
||||
plot_images(batch['img'],
|
||||
batch['batch_idx'],
|
||||
batch['cls'].squeeze(-1),
|
||||
@ -147,6 +153,7 @@ class PoseValidator(DetectionValidator):
|
||||
names=self.names)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots predictions for YOLO model."""
|
||||
pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape)[:15] for p in preds], 0)
|
||||
plot_images(batch['img'],
|
||||
*output_to_target(preds, max_det=15),
|
||||
@ -156,6 +163,7 @@ class PoseValidator(DetectionValidator):
|
||||
names=self.names) # pred
|
||||
|
||||
def pred_to_json(self, predn, filename):
|
||||
"""Converts YOLO predictions to COCO JSON format."""
|
||||
stem = Path(filename).stem
|
||||
image_id = int(stem) if stem.isnumeric() else stem
|
||||
box = ops.xyxy2xywh(predn[:, :4]) # xywh
|
||||
@ -169,6 +177,7 @@ class PoseValidator(DetectionValidator):
|
||||
'score': round(p[4], 5)})
|
||||
|
||||
def eval_json(self, stats):
|
||||
"""Evaluates object detection model using COCO JSON format."""
|
||||
if self.args.save_json and self.is_coco and len(self.jdict):
|
||||
anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations
|
||||
pred_json = self.save_dir / 'predictions.json' # predictions
|
||||
@ -197,6 +206,7 @@ class PoseValidator(DetectionValidator):
|
||||
|
||||
|
||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Performs validation on YOLO model using given data."""
|
||||
model = cfg.model or 'yolov8n-pose.pt'
|
||||
data = cfg.data or 'coco8-pose.yaml'
|
||||
|
||||
|
@ -41,6 +41,7 @@ class SegmentationPredictor(DetectionPredictor):
|
||||
|
||||
|
||||
def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Runs YOLO object detection on an image or video source."""
|
||||
model = cfg.model or 'yolov8n-seg.pt'
|
||||
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||
else 'https://ultralytics.com/images/bus.jpg'
|
||||
|
@ -18,12 +18,14 @@ from ultralytics.yolo.v8.detect.train import Loss
|
||||
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """Set the 'task' override to 'segment', then hand cfg, overrides and callbacks to the parent initializer."""
    overrides = {} if overrides is None else overrides
    overrides['task'] = 'segment'  # written into the caller's mapping when one was supplied
    super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Return SegmentationModel initialized with specified config and weights."""
|
||||
model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
@ -31,15 +33,18 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
|
||||
return model
|
||||
|
||||
def get_validator(self):
    """Set the segmentation loss names and return a SegmentationValidator built from the test loader and copied args."""
    self.loss_names = ('box_loss', 'seg_loss', 'cls_loss', 'dfl_loss')
    validator = v8.segment.SegmentationValidator(
        self.test_loader,
        save_dir=self.save_dir,
        args=copy(self.args),  # the validator receives its own copy of the args
    )
    return validator
|
||||
|
||||
def criterion(self, preds, batch):
    """Apply the cached loss callable to preds and batch, creating a SegLoss (overlap from args) on first use."""
    try:
        loss_fn = self.compute_loss
    except AttributeError:
        # First call: build the loss object once and keep it on the trainer.
        loss_fn = self.compute_loss = SegLoss(de_parallel(self.model), overlap=self.args.overlap_mask)
    return loss_fn(preds, batch)
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Creates a plot of training sample images with labels and box coordinates."""
|
||||
images = batch['img']
|
||||
masks = batch['masks']
|
||||
cls = batch['cls'].squeeze(-1)
|
||||
@ -49,6 +54,7 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
|
||||
plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg')
|
||||
|
||||
def plot_metrics(self):
    """Call plot_results on the trainer's CSV file with segment=True."""
    csv_path = self.csv
    plot_results(file=csv_path, segment=True)  # save results.png
|
||||
|
||||
|
||||
@ -61,6 +67,7 @@ class SegLoss(Loss):
|
||||
self.overlap = overlap
|
||||
|
||||
def __call__(self, preds, batch):
|
||||
"""Calculate and return the loss for the YOLO model."""
|
||||
loss = torch.zeros(4, device=self.device) # box, cls, dfl
|
||||
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
|
||||
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
|
||||
@ -147,6 +154,7 @@ class SegLoss(Loss):
|
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Train a YOLO segmentation model based on passed arguments."""
|
||||
model = cfg.model or 'yolov8n-seg.pt'
|
||||
data = cfg.data or 'coco128-seg.yaml' # or yolo.ClassificationDataset("mnist")
|
||||
device = cfg.device if cfg.device is not None else ''
|
||||
|
@ -17,16 +17,19 @@ from ultralytics.yolo.v8.detect import DetectionValidator
|
||||
class SegmentationValidator(DetectionValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
    """Run the parent validator setup, then set args.task to 'segment' and attach SegmentMetrics for the save dir."""
    super().__init__(dataloader, save_dir, pbar, args, _callbacks)
    # Task tag and metrics container are fixed for this validator.
    self.args.task = 'segment'
    self.metrics = SegmentMetrics(save_dir=self.save_dir)
|
||||
|
||||
def preprocess(self, batch):
    """Apply the parent's preprocessing, then move batch['masks'] to the device and cast it to float."""
    batch = super().preprocess(batch)
    masks = batch['masks'].to(self.device)
    batch['masks'] = masks.float()
    return batch
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initialize metrics and select mask processing function based on save_json flag."""
|
||||
super().init_metrics(model)
|
||||
self.plot_masks = []
|
||||
if self.args.save_json:
|
||||
@ -36,10 +39,12 @@ class SegmentationValidator(DetectionValidator):
|
||||
self.process = ops.process_mask # faster
|
||||
|
||||
def get_desc(self):
    """Return the results header: 'Class' in 22 chars, then Images/Instances and the Box and Mask metric columns."""
    box_columns = ('Box(P', 'R', 'mAP50', 'mAP50-95)')
    mask_columns = ('Mask(P', 'R', 'mAP50', 'mAP50-95)')
    narrow = ('Images', 'Instances', *box_columns, *mask_columns)  # ten 11-char columns
    return 'Class'.rjust(22) + ''.join(col.rjust(11) for col in narrow)
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Postprocesses YOLO predictions and returns output detections with proto."""
|
||||
p = ops.non_max_suppression(preds[0],
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
@ -119,6 +124,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
    """Copy the validator's speed and confusion matrix onto its metrics object; extra arguments are unused."""
    metrics = self.metrics
    metrics.speed = self.speed
    metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
@ -160,6 +166,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plots validation samples with bounding box labels."""
|
||||
plot_images(batch['img'],
|
||||
batch['batch_idx'],
|
||||
batch['cls'].squeeze(-1),
|
||||
@ -170,6 +177,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
names=self.names)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots batch predictions with masks and bounding boxes."""
|
||||
plot_images(batch['img'],
|
||||
*output_to_target(preds[0], max_det=15),
|
||||
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
|
||||
@ -184,6 +192,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
from pycocotools.mask import encode # noqa
|
||||
|
||||
def single_encode(x):
    """Encode one mask with pycocotools `encode` and return the RLE dict with 'counts' decoded to a UTF-8 str."""
    # encode() is given a Fortran-ordered uint8 array with a trailing axis added to the mask.
    mask = np.asarray(x[:, :, None], order='F', dtype='uint8')
    rle = encode(mask)[0]
    rle['counts'] = rle['counts'].decode('utf-8')
    return rle
|
||||
@ -204,6 +213,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
'segmentation': rles[i]})
|
||||
|
||||
def eval_json(self, stats):
|
||||
"""Return COCO-style object detection evaluation metrics."""
|
||||
if self.args.save_json and self.is_coco and len(self.jdict):
|
||||
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
|
||||
pred_json = self.save_dir / 'predictions.json' # predictions
|
||||
@ -232,6 +242,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
|
||||
|
||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Validate trained YOLO model on validation data."""
|
||||
model = cfg.model or 'yolov8n-seg.pt'
|
||||
data = cfg.data or 'coco128-seg.yaml'
|
||||
|
||||
|
Reference in New Issue
Block a user