ultralytics 8.0.65
YOLOv8 Pose models (#1347)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mert Can Demir <validatedev@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Fabian Greavu <fabiangreavu@gmail.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Eric Pedley <ericpedley@gmail.com>
Co-authored-by: JustasBart <40023722+JustasBart@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Aarni Koskela <akx@iki.fi>
Co-authored-by: Sergio Sanchez <sergio.ssm.97@gmail.com>
Co-authored-by: Bogdan Gheorghe <112427971+bogdan-galileo@users.noreply.github.com>
Co-authored-by: Jaap van de Loosdrecht <jaap@vdlmv.nl>
Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com>
Co-authored-by: nerdyespresso <106761627+nerdyespresso@users.noreply.github.com>
Co-authored-by: Farid Inawan <frdteknikelektro@gmail.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Alexander Duda <Alexander.Duda@me.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
Co-authored-by: Snyk bot <snyk-bot@snyk.io>
Co-authored-by: majid nasiri <majnasai@gmail.com>
@@ -75,11 +75,13 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, hal
             # Validate
             if model.task == 'detect':
-                data, key = 'coco128.yaml', 'metrics/mAP50-95(B)'
+                data, key = 'coco8.yaml', 'metrics/mAP50-95(B)'
             elif model.task == 'segment':
-                data, key = 'coco128-seg.yaml', 'metrics/mAP50-95(M)'
+                data, key = 'coco8-seg.yaml', 'metrics/mAP50-95(M)'
             elif model.task == 'classify':
                 data, key = 'imagenet100', 'metrics/accuracy_top5'
+            elif model.task == 'pose':
+                data, key = 'coco8-pose.yaml', 'metrics/mAP50-95(P)'

             results = export.val(data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, verbose=False)
             metric, speed = results.results_dict[key], results.speed['inference']
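A minimal sketch of the new task routing in use (a sketch only, assuming ultralytics==8.0.65; the import path and keyword arguments are inferred from this hunk, not confirmed by it):

    from ultralytics.yolo.utils.benchmarks import benchmark  # path assumed

    # Each export format is validated on the small coco8* datasets; a pose model
    # is routed to coco8-pose.yaml and scored with 'metrics/mAP50-95(P)'.
    benchmark(model='yolov8n-pose.pt', imgsz=160, half=False)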
@@ -14,9 +14,9 @@ from tqdm import tqdm

 from ultralytics.yolo.utils import LOGGER, checks, emojis, is_online

-GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \
-                     [f'yolov5{size}u.pt' for size in 'nsmlx'] + \
-                     [f'yolov3{size}u.pt' for size in ('', '-spp', '-tiny')]
+GITHUB_ASSET_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in ('', '6', '-cls', '-seg', '-pose')] + \
+                     [f'yolov5{k}u.pt' for k in 'nsmlx'] + \
+                     [f'yolov3{k}u.pt' for k in ('', '-spp', '-tiny')]
 GITHUB_ASSET_STEMS = [Path(k).stem for k in GITHUB_ASSET_NAMES]
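Illustration of what the updated comprehension yields (runnable as-is):

    sizes, suffixes = 'nsmlx', ('', '6', '-cls', '-seg', '-pose')
    names = [f'yolov8{k}{s}.pt' for k in sizes for s in suffixes]
    assert 'yolov8n-pose.pt' in names and 'yolov8x-pose.pt' in names  # pose weights now resolvable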
@@ -168,7 +168,7 @@ class Instances:
         Args:
             bboxes (ndarray): bboxes with shape [N, 4].
             segments (list | ndarray): segments.
-            keypoints (ndarray): keypoints with shape [N, 17, 2].
+            keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3].
         """
         if segments is None:
             segments = []
@@ -54,3 +54,17 @@ class BboxLoss(nn.Module):
         wr = 1 - wl  # weight right
         return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
                 F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
+
+
+class KeypointLoss(nn.Module):
+
+    def __init__(self, sigmas) -> None:
+        super().__init__()
+        self.sigmas = sigmas
+
+    def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
+        d = (pred_kpts[..., 0] - gt_kpts[..., 0]) ** 2 + (pred_kpts[..., 1] - gt_kpts[..., 1]) ** 2
+        kpt_loss_factor = (torch.sum(kpt_mask != 0) + torch.sum(kpt_mask == 0)) / (torch.sum(kpt_mask != 0) + 1e-9)
+        # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9)  # from formula
+        e = d / (2 * self.sigmas) ** 2 / (area + 1e-9) / 2  # from cocoeval
+        return kpt_loss_factor * ((1 - torch.exp(-e)) * kpt_mask).mean()
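A standalone sketch of the new KeypointLoss (the shapes and the import paths are illustrative assumptions inferred from the BboxLoss context above; in training the pose trainer supplies these tensors):

    import torch
    from ultralytics.yolo.utils.loss import KeypointLoss  # path assumed
    from ultralytics.yolo.utils.metrics import OKS_SIGMA  # path assumed

    loss_fn = KeypointLoss(sigmas=torch.from_numpy(OKS_SIGMA))
    pred = torch.rand(8, 17, 2) * 640    # predicted (x, y) per keypoint
    gt = torch.rand(8, 17, 2) * 640      # ground-truth (x, y)
    mask = torch.ones(8, 17)             # 1 where a keypoint is labeled visible
    area = torch.rand(8, 1) * 640 ** 2   # gt box areas normalize the distances
    print(loss_fn(pred, gt, mask, area))  # scalar OKS-style penalty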
@@ -13,6 +13,8 @@ import torch.nn as nn

 from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept

+OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
+

 # boxes
 def box_area(box):
@@ -108,8 +110,8 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7

 def mask_iou(mask1, mask2, eps=1e-7):
     """
-    mask1: [N, n] m1 means number of predicted objects
-    mask2: [M, n] m2 means number of gt objects
+    mask1: [N, n] m1 means number of gt objects
+    mask2: [M, n] m2 means number of predicted objects
     Note: n means image_w x image_h
     Returns: masks iou, [N, M]
     """
@@ -118,16 +120,18 @@ def mask_iou(mask1, mask2, eps=1e-7):
     return intersection / (union + eps)


-def masks_iou(mask1, mask2, eps=1e-7):
-    """
-    mask1: [N, n] m1 means number of predicted objects
-    mask2: [N, n] m2 means number of gt objects
-    Note: n means image_w x image_h
-    Returns: masks iou, (N, )
-    """
-    intersection = (mask1 * mask2).sum(1).clamp(0)  # (N, )
-    union = (mask1.sum(1) + mask2.sum(1))[None] - intersection  # (area1 + area2) - intersection
-    return intersection / (union + eps)
+def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
+    """OKS
+    kpt1: [N, 17, 3], gt
+    kpt2: [M, 17, 3], pred
+    area: [N], areas from gt
+    """
+    d = (kpt1[:, None, :, 0] - kpt2[..., 0]) ** 2 + (kpt1[:, None, :, 1] - kpt2[..., 1]) ** 2  # (N, M, 17)
+    sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype)  # (17, )
+    kpt_mask = kpt1[..., 2] != 0  # (N, 17)
+    e = d / (2 * sigma) ** 2 / (area[:, None, None] + eps) / 2  # from cocoeval
+    # e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2  # from formula
+    return (torch.exp(-e) * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)


 def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
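A short sketch of kpt_iou (OKS) with the shapes from its docstring; OKS_SIGMA is the module-level array added above:

    import torch

    gt = torch.rand(3, 17, 3)     # (x, y, visible) ground-truth keypoints
    pred = torch.rand(5, 17, 3)   # predicted keypoints
    area = torch.rand(3) * 100    # gt box areas
    oks = kpt_iou(gt, pred, area, OKS_SIGMA)  # (3, 5) pairwise OKS matrix in [0, 1]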
@@ -649,13 +653,13 @@ class SegmentMetrics(SimpleClass):
         self.seg = Metric()
         self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}

-    def process(self, tp_m, tp_b, conf, pred_cls, target_cls):
+    def process(self, tp_b, tp_m, conf, pred_cls, target_cls):
         """
         Processes the detection and segmentation metrics over the given set of predictions.

         Args:
-            tp_m (list): List of True Positive masks.
-            tp_b (list): List of True Positive boxes.
+            tp_b (list): List of True Positive boxes.
+            tp_m (list): List of True Positive masks.
             conf (list): List of confidence scores.
             pred_cls (list): List of predicted classes.
             target_cls (list): List of target classes.
@@ -712,6 +716,100 @@ class SegmentMetrics(SimpleClass):
         return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))


+class PoseMetrics(SegmentMetrics):
+    """
+    Calculates and aggregates detection and pose metrics over a given set of classes.
+
+    Args:
+        save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
+        plot (bool): Whether to save the detection and pose plots. Default is False.
+        names (list): List of class names. Default is an empty list.
+
+    Attributes:
+        save_dir (Path): Path to the directory where the output plots should be saved.
+        plot (bool): Whether to save the detection and pose plots.
+        names (list): List of class names.
+        box (Metric): An instance of the Metric class to calculate box detection metrics.
+        pose (Metric): An instance of the Metric class to calculate pose keypoint metrics.
+        speed (dict): Dictionary to store the time taken in different phases of inference.
+
+    Methods:
+        process(tp_b, tp_p, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
+        mean_results(): Returns the mean of the detection and pose metrics over all the classes.
+        class_result(i): Returns the detection and pose metrics of class `i`.
+        maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
+        fitness: Returns the fitness scores, which are a single weighted combination of metrics.
+        ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
+        results_dict: Returns the dictionary containing all the detection and pose metrics and fitness score.
+    """
+
+    def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
+        super().__init__(save_dir, plot, names)
+        self.save_dir = save_dir
+        self.plot = plot
+        self.names = names
+        self.box = Metric()
+        self.pose = Metric()
+        self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
+
+    def __getattr__(self, attr):
+        name = self.__class__.__name__
+        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
+
+    def process(self, tp_b, tp_p, conf, pred_cls, target_cls):
+        """
+        Processes the detection and pose metrics over the given set of predictions.
+
+        Args:
+            tp_b (list): List of True Positive boxes.
+            tp_p (list): List of True Positive keypoints.
+            conf (list): List of confidence scores.
+            pred_cls (list): List of predicted classes.
+            target_cls (list): List of target classes.
+        """
+
+        results_pose = ap_per_class(tp_p,
+                                    conf,
+                                    pred_cls,
+                                    target_cls,
+                                    plot=self.plot,
+                                    save_dir=self.save_dir,
+                                    names=self.names,
+                                    prefix='Pose')[2:]
+        self.pose.nc = len(self.names)
+        self.pose.update(results_pose)
+        results_box = ap_per_class(tp_b,
+                                   conf,
+                                   pred_cls,
+                                   target_cls,
+                                   plot=self.plot,
+                                   save_dir=self.save_dir,
+                                   names=self.names,
+                                   prefix='Box')[2:]
+        self.box.nc = len(self.names)
+        self.box.update(results_box)
+
+    @property
+    def keys(self):
+        return [
+            'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
+            'metrics/precision(P)', 'metrics/recall(P)', 'metrics/mAP50(P)', 'metrics/mAP50-95(P)']
+
+    def mean_results(self):
+        return self.box.mean_results() + self.pose.mean_results()
+
+    def class_result(self, i):
+        return self.box.class_result(i) + self.pose.class_result(i)
+
+    @property
+    def maps(self):
+        return self.box.maps + self.pose.maps
+
+    @property
+    def fitness(self):
+        return self.pose.fitness() + self.box.fitness()
+
+
 class ClassifyMetrics(SimpleClass):
     """
     Class for computing classification metrics including top-1 and top-5 accuracy.
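A usage sketch for PoseMetrics (the inputs are stand-ins; a pose validator accumulates tp_b, tp_p, conf, pred_cls and target_cls across batches before calling process()):

    from pathlib import Path

    metrics = PoseMetrics(save_dir=Path('.'), plot=False, names={0: 'person'})
    # metrics.process(tp_b, tp_p, conf, pred_cls, target_cls)  # with accumulated stats
    print(metrics.keys)  # ['metrics/precision(B)', ..., 'metrics/mAP50-95(P)']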
@@ -281,28 +281,23 @@ def clip_boxes(boxes, shape):
         boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2


-def clip_coords(boxes, shape):
+def clip_coords(coords, shape):
     """
-    Clip bounding xyxy bounding boxes to image shape (height, width).
+    Clip line coordinates to the image boundaries.

     Args:
-        boxes (torch.Tensor or numpy.ndarray): Bounding boxes to be clipped.
-        shape (tuple): The shape of the image. (height, width)
+        coords (torch.Tensor) or (numpy.ndarray): A list of line coordinates.
+        shape (tuple): A tuple of integers representing the size of the image in the format (height, width).

     Returns:
-        None
-
-    Note:
-        The input `boxes` is modified in-place, there is no return value.
+        (None): The function modifies the input `coordinates` in place, by clipping each coordinate to the image boundaries.
     """
-    if isinstance(boxes, torch.Tensor):  # faster individually
-        boxes[:, 0].clamp_(0, shape[1])  # x1
-        boxes[:, 1].clamp_(0, shape[0])  # y1
-        boxes[:, 2].clamp_(0, shape[1])  # x2
-        boxes[:, 3].clamp_(0, shape[0])  # y2
+    if isinstance(coords, torch.Tensor):  # faster individually
+        coords[..., 0].clamp_(0, shape[1])  # x
+        coords[..., 1].clamp_(0, shape[0])  # y
     else:  # np.array (faster grouped)
-        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1])  # x1, x2
-        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0])  # y1, y2
+        coords[..., 0] = coords[..., 0].clip(0, shape[1])  # x
+        coords[..., 1] = coords[..., 1].clip(0, shape[0])  # y


 def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):
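The rewrite generalizes clipping from xyxy boxes to any array whose trailing dimension holds (x, y), so [N, 17, 2] keypoints clip directly. A quick in-place illustration:

    import torch

    kpts = torch.tensor([[[-5.0, 10.0], [700.0, 300.0]]])  # (1, 2, 2) x,y pairs
    clip_coords(kpts, (480, 640))  # shape is (height, width)
    print(kpts)  # tensor([[[  0.,  10.], [640., 300.]]])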
@@ -577,17 +572,18 @@ def process_mask_upsample(protos, masks_in, bboxes, shape):

 def process_mask(protos, masks_in, bboxes, shape, upsample=False):
     """
-    It takes the output of the mask head, and applies the mask to the bounding boxes. This is faster but produces
-    downsampled quality of mask
+    Apply masks to bounding boxes using the output of the mask head.

     Args:
-        protos (torch.Tensor): [mask_dim, mask_h, mask_w]
-        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
-        bboxes (torch.Tensor): [n, 4], n is number of masks after nms
-        shape (tuple): the size of the input image (h,w)
+        protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w].
+        masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
+        bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS.
+        shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
+        upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False.

     Returns:
-        (torch.Tensor): The processed masks.
+        (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
+            are the height and width of the input image. The mask is applied to the bounding boxes.
     """

     c, mh, mw = protos.shape  # CHW
@@ -632,19 +628,19 @@ def process_mask_native(protos, masks_in, bboxes, shape):
     return masks.gt_(0.5)


-def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
+def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False):
     """
     Rescale segment coordinates (xyxy) from img1_shape to img0_shape

     Args:
-        img1_shape (tuple): The shape of the image that the segments are from.
-        segments (torch.Tensor): the segments to be scaled
+        img1_shape (tuple): The shape of the image that the coords are from.
+        coords (torch.Tensor): the coords to be scaled
         img0_shape (tuple): the shape of the image that the segmentation is being applied to
         ratio_pad (tuple): the ratio of the image size to the padded image size.
         normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False

     Returns:
-        segments (torch.Tensor): the segmented image.
+        coords (torch.Tensor): the segmented image.
     """
     if ratio_pad is None:  # calculate from img0_shape
         gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
@@ -653,14 +649,15 @@ def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=F
         gain = ratio_pad[0][0]
         pad = ratio_pad[1]

-    segments[:, 0] -= pad[0]  # x padding
-    segments[:, 1] -= pad[1]  # y padding
-    segments /= gain
-    clip_segments(segments, img0_shape)
+    coords[..., 0] -= pad[0]  # x padding
+    coords[..., 1] -= pad[1]  # y padding
+    coords[..., 0] /= gain
+    coords[..., 1] /= gain
+    clip_coords(coords, img0_shape)
     if normalize:
-        segments[:, 0] /= img0_shape[1]  # width
-        segments[:, 1] /= img0_shape[0]  # height
-    return segments
+        coords[..., 0] /= img0_shape[1]  # width
+        coords[..., 1] /= img0_shape[0]  # height
+    return coords


 def masks2segments(masks, strategy='largest'):
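A worked example of the renamed scale_coords, assuming the usual gain/pad computation in the unchanged lines elided from this hunk (gain = min(640/480, 640/640) = 1.0 and pad = (0, 80) when a 480x640 original is letterboxed to 640x640):

    import torch

    kpts = torch.tensor([[320.0, 320.0]])  # (x, y) in the 640x640 network image
    print(scale_coords((640, 640), kpts, (480, 640)))  # tensor([[320., 240.]])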
@@ -688,23 +685,6 @@ def masks2segments(masks, strategy='largest'):
     return segments


-def clip_segments(segments, shape):
-    """
-    It takes a list of line segments (x1,y1,x2,y2) and clips them to the image shape (height, width)
-
-    Args:
-        segments (list): a list of segments, each segment is a list of points, each point is a list of x,y
-            coordinates
-        shape (tuple): the shape of the image
-    """
-    if isinstance(segments, torch.Tensor):  # faster individually
-        segments[:, 0].clamp_(0, shape[1])  # x
-        segments[:, 1].clamp_(0, shape[0])  # y
-    else:  # np.array (faster grouped)
-        segments[:, 0] = segments[:, 0].clip(0, shape[1])  # x
-        segments[:, 1] = segments[:, 1].clip(0, shape[0])  # y
-
-
 def clean_str(s):
     """
     Cleans a string by replacing special characters with underscore _
@@ -16,7 +16,7 @@ from ultralytics.yolo.utils import LOGGER, TryExcept, threaded

 from .checks import check_font, check_version, is_ascii
 from .files import increment_path
-from .ops import clip_coords, scale_image, xywh2xyxy, xyxy2xywh
+from .ops import clip_boxes, scale_image, xywh2xyxy, xyxy2xywh

 matplotlib.rc('font', **{'size': 11})
 matplotlib.use('Agg')  # for writing to files only
@@ -30,6 +30,11 @@ class Colors:
                 '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
         self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
         self.n = len(self.palette)
+        self.pose_palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], [230, 230, 0], [255, 153, 255],
+                                      [153, 204, 255], [255, 102, 255], [255, 51, 255], [102, 178, 255], [51, 153, 255],
+                                      [255, 153, 153], [255, 102, 102], [255, 51, 51], [153, 255, 153], [102, 255, 102],
+                                      [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0], [255, 255, 255]],
+                                     dtype=np.uint8)

     def __call__(self, i, bgr=False):
         c = self.palette[int(i) % self.n]
@@ -62,6 +67,12 @@ class Annotator:
         else:  # use cv2
             self.im = im
         self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2)  # line width
+        # pose
+        self.skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9],
+                         [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]
+
+        self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
+        self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]

     def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
         # Add one xyxy box to image with label
@@ -132,6 +143,49 @@ class Annotator:
             # convert im back to PIL and update draw
             self.fromarray(self.im)

+    def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True):
+        """Plot keypoints.
+        Args:
+            kpts (tensor): predicted kpts, shape: [17, 3]
+            shape (tuple): image shape, (h, w)
+            radius (int): size of drawing points
+            kpt_line (bool): whether to draw limb lines between keypoints
+        """
+        if self.pil:
+            # convert to numpy first
+            self.im = np.asarray(self.im).copy()
+        nkpt, ndim = kpts.shape
+        is_pose = nkpt == 17 and ndim == 3
+        kpt_line &= is_pose  # `kpt_line=True` for now only supports human pose plotting
+        for i, k in enumerate(kpts):
+            color_k = [int(x) for x in self.kpt_color[i]] if is_pose else colors(i)
+            x_coord, y_coord = k[0], k[1]
+            if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
+                if len(k) == 3:
+                    conf = k[2]
+                    if conf < 0.5:
+                        continue
+                cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1)
+
+        if kpt_line:
+            ndim = kpts.shape[-1]
+            for sk_id, sk in enumerate(self.skeleton):
+                pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))
+                pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))
+                if ndim == 3:
+                    conf1 = kpts[(sk[0] - 1), 2]
+                    conf2 = kpts[(sk[1] - 1), 2]
+                    if conf1 < 0.5 or conf2 < 0.5:
+                        continue
+                if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:
+                    continue
+                if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:
+                    continue
+                cv2.line(self.im, pos1, pos2, [int(x) for x in self.limb_color[sk_id]], thickness=2)
+        if self.pil:
+            # convert im back to PIL and update draw
+            self.fromarray(self.im)
+
     def rectangle(self, xy, fill=None, outline=None, width=1):
         # Add rectangle to image (PIL-only)
         self.draw.rectangle(xy, fill, outline, width)
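A plotting sketch for the new kpts() helper (the image and keypoints are synthetic; the [17, 3] layout with visibility >= 0.5 ensures points and limbs draw):

    import numpy as np

    im = np.zeros((640, 640, 3), dtype=np.uint8)
    ann = Annotator(im, line_width=2)
    kpts = np.random.rand(17, 3).astype(np.float32)
    kpts[:, :2] *= 640  # x, y in pixels
    kpts[:, 2] = 1.0    # visibility/confidence above the 0.5 threshold
    ann.kpts(kpts, shape=(640, 640))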
@@ -213,7 +267,7 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
     b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
     b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
     xyxy = xywh2xyxy(b).long()
-    clip_coords(xyxy, im.shape)
+    clip_boxes(xyxy, im.shape)
     crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
     if save:
         file.parent.mkdir(parents=True, exist_ok=True)  # make directory
@@ -229,6 +283,7 @@ def plot_images(images,
                 cls,
                 bboxes,
                 masks=np.zeros(0, dtype=np.uint8),
+                kpts=np.zeros((0, 51), dtype=np.float32),
                 paths=None,
                 fname='images.jpg',
                 names=None):
@@ -241,6 +296,8 @@ def plot_images(images,
         bboxes = bboxes.cpu().numpy()
     if isinstance(masks, torch.Tensor):
         masks = masks.cpu().numpy().astype(int)
+    if isinstance(kpts, torch.Tensor):
+        kpts = kpts.cpu().numpy()
     if isinstance(batch_idx, torch.Tensor):
         batch_idx = batch_idx.cpu().numpy()
@@ -300,6 +357,21 @@ def plot_images(images,
                     label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
                     annotator.box_label(box, label, color=color)

+            # Plot keypoints
+            if len(kpts):
+                kpts_ = kpts[idx].copy()
+                if len(kpts_):
+                    if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01:  # if normalized with tolerance .01
+                        kpts_[..., 0] *= w  # scale to pixels
+                        kpts_[..., 1] *= h
+                    elif scale < 1:  # absolute coords need scale if image scales
+                        kpts_ *= scale
+                kpts_[..., 0] += x
+                kpts_[..., 1] += y
+                for j in range(len(kpts_)):
+                    if labels or conf[j] > 0.25:  # 0.25 conf thresh
+                        annotator.kpts(kpts_[j])
+
             # Plot masks
             if len(masks):
                 if idx.shape[0] == masks.shape[0]:  # overlap_masks=False
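Note on the kpts batch layout: the (0, 51) default encodes zero instances of 17 COCO keypoints x 3 values (x, y, visibility) flattened; non-empty callers pass (n, 17, 3) arrays, which the block above treats as normalized when x and y stay within ~1.01:

    import numpy as np

    kpts = np.random.rand(2, 17, 3)    # two normalized instances
    assert kpts[..., 0].max() <= 1.01  # triggers the scale-to-pixels branch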
@@ -307,7 +379,7 @@ def plot_images(images,
             else:  # overlap_masks=True
                 image_masks = masks[[i]]  # (1, 640, 640)
                 nl = idx.sum()
-                index = np.arange(nl).reshape(nl, 1, 1) + 1
+                index = np.arange(nl).reshape((nl, 1, 1)) + 1
                 image_masks = np.repeat(image_masks, nl, axis=0)
                 image_masks = np.where(image_masks == index, 1.0, 0.0)
@@ -328,13 +400,16 @@ def plot_images(images,
     annotator.im.save(fname)  # save


-def plot_results(file='path/to/results.csv', dir='', segment=False):
+def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False):
     # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
     import pandas as pd
     save_dir = Path(file).parent if file else Path(dir)
     if segment:
         fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
         index = [1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]
+    elif pose:
+        fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
+        index = [1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 18, 8, 9, 12, 13]
     else:
         fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
         index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
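Usage sketch (the CSV path is illustrative and the module path is assumed from this file's context): pose runs log 18 result columns, hence the 2x9 grid above.

    from ultralytics.yolo.utils.plotting import plot_results  # path assumed

    plot_results('runs/pose/train/results.csv', pose=True)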
@@ -240,8 +240,8 @@ def copy_attr(a, b, include=(), exclude=()):


 def get_latest_opset():
-    # Return max supported ONNX opset by this version of torch
-    return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k)  # opset
+    # Return second-most (for maturity) recently supported ONNX opset by this version of torch
+    return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) - 1  # opset


 def intersect_dicts(da, db, exclude=()):
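Illustration of the opset change: torch exposes symbolic_opset<N> submodules, and the helper now returns max(N) - 1 so ONNX export defaults to a mature opset (e.g. a torch build exposing up to opset 18 yields 17):

    import torch

    opsets = [int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k]
    print(max(opsets) - 1)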
@@ -318,18 +318,18 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
     """
     Strip optimizer from 'f' to finalize training, optionally save as 's'.

-    Usage:
-        from ultralytics.yolo.utils.torch_utils import strip_optimizer
-        from pathlib import Path
-        for f in Path('/Users/glennjocher/Downloads/weights').glob('*.pt'):
-            strip_optimizer(f)
-
     Args:
         f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
         s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.

     Returns:
         None

+    Usage:
+        from pathlib import Path
+        from ultralytics.yolo.utils.torch_utils import strip_optimizer
+        for f in Path('/Users/glennjocher/Downloads/weights').rglob('*.pt'):
+            strip_optimizer(f)
     """
     x = torch.load(f, map_location=torch.device('cpu'))
     args = {**DEFAULT_CFG_DICT, **x['train_args']}  # combine model args with default args, preferring model args
@@ -349,7 +349,9 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:


 def profile(input, ops, n=10, device=None):
-    """ YOLOv8 speed/memory/FLOPs profiler
+    """
+    YOLOv8 speed/memory/FLOPs profiler
+
     Usage:
         input = torch.randn(16, 3, 640, 640)
         m1 = lambda x: x * torch.sigmoid(x)