ultralytics 8.0.107
(#2778)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Peter van Lunteren <contact@pvanlunteren.com>
This commit is contained in:
@ -766,9 +766,17 @@ def v8_transforms(dataset, imgsz, hyp):
|
||||
pre_transform=LetterBox(new_shape=(imgsz, imgsz)),
|
||||
)])
|
||||
flip_idx = dataset.data.get('flip_idx', None) # for keypoints augmentation
|
||||
if dataset.use_keypoints and flip_idx is None and hyp.fliplr > 0.0:
|
||||
hyp.fliplr = 0.0
|
||||
LOGGER.warning("WARNING ⚠️ No `flip_idx` provided while training keypoints, setting augmentation 'fliplr=0.0'")
|
||||
if dataset.use_keypoints:
|
||||
kpt_shape = dataset.data.get('kpt_shape', None)
|
||||
if flip_idx is None and hyp.fliplr > 0.0:
|
||||
hyp.fliplr = 0.0
|
||||
LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
|
||||
elif flip_idx:
|
||||
if len(flip_idx) != kpt_shape[0]:
|
||||
raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}')
|
||||
elif flip_idx[0] != 0:
|
||||
raise ValueError(f'data.yaml flip_idx={flip_idx} must be zero-index (start from 0)')
|
||||
|
||||
return Compose([
|
||||
pre_transform,
|
||||
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
|
||||
|
@ -266,7 +266,7 @@ def check_det_dataset(dataset, autodownload=True):
|
||||
return data # dictionary
|
||||
|
||||
|
||||
def check_cls_dataset(dataset: str):
|
||||
def check_cls_dataset(dataset: str, split=''):
|
||||
"""
|
||||
Check a classification dataset such as Imagenet.
|
||||
|
||||
@ -275,6 +275,7 @@ def check_cls_dataset(dataset: str):
|
||||
|
||||
Args:
|
||||
dataset (str): Name of the dataset.
|
||||
split (str, optional): Dataset split, either 'val', 'test', or ''. Defaults to ''.
|
||||
|
||||
Returns:
|
||||
data (dict): A dictionary containing the following keys and values:
|
||||
@ -298,10 +299,15 @@ def check_cls_dataset(dataset: str):
|
||||
train_set = data_dir / 'train'
|
||||
val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val
|
||||
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
|
||||
if split == 'val' and not val_set:
|
||||
LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
|
||||
elif split == 'test' and not test_set:
|
||||
LOGGER.info("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
|
||||
|
||||
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
|
||||
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
|
||||
names = dict(enumerate(sorted(names)))
|
||||
return {'train': train_set, 'val': val_set, 'test': test_set, 'nc': nc, 'names': names}
|
||||
return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}
|
||||
|
||||
|
||||
class HUBDatasetStats():
|
||||
|
@ -126,7 +126,7 @@ class BaseValidator:
|
||||
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
|
||||
self.data = check_det_dataset(self.args.data)
|
||||
elif self.args.task == 'classify':
|
||||
self.data = check_cls_dataset(self.args.data)
|
||||
self.data = check_cls_dataset(self.args.data, split=self.args.split)
|
||||
else:
|
||||
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
|
||||
|
||||
|
@ -3,7 +3,8 @@
|
||||
Benchmark YOLO model formats for speed and accuracy
|
||||
|
||||
Usage:
|
||||
from ultralytics.yolo.utils.benchmarks import run_benchmarks
|
||||
from ultralytics.yolo.utils.benchmarks import ProfileModels, run_benchmarks
|
||||
ProfileModels(['yolov8n.yaml', 'yolov8s.yaml'])
|
||||
run_benchmarks(model='yolov8n.pt', imgsz=160)
|
||||
|
||||
Format | `format=argument` | Model
|
||||
@ -22,14 +23,19 @@ TensorFlow.js | `tfjs` | yolov8n_web_model/
|
||||
PaddlePaddle | `paddle` | yolov8n_paddle_model/
|
||||
"""
|
||||
|
||||
import glob
|
||||
import platform
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch.cuda
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.yolo.engine.exporter import export_formats
|
||||
from ultralytics.yolo.utils import LINUX, LOGGER, MACOS, ROOT, SETTINGS
|
||||
from ultralytics.yolo.utils.checks import check_yolo
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_yolo
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.yolo.utils.files import file_size
|
||||
from ultralytics.yolo.utils.torch_utils import select_device
|
||||
@ -140,5 +146,140 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
|
||||
return df
|
||||
|
||||
|
||||
class ProfileModels:
    """
    ProfileModels class for profiling different models on ONNX and TensorRT.

    This class profiles the performance of different models, provided their paths. The profiling includes parameters such as
    model speed and FLOPs. Profiling starts immediately on construction and the results are printed as a Markdown table.

    Attributes:
        paths (list): Paths of the models to profile (files, directories or glob patterns).
        num_timed_runs (int): Number of timed runs for the profiling. Default is 100.
        num_warmup_runs (int): Number of warmup runs before profiling. Default is 3.
        imgsz (int): Image size used in the models. Default is 640.
        trt (bool): Whether to export to and profile TensorRT when a CUDA device is available. Default is True.

    Methods:
        profile(): Profiles the models and prints the result.
    """

    def __init__(self, paths: list, num_timed_runs=100, num_warmup_runs=3, imgsz=640, trt=True):
        """Store the profiling configuration and immediately run the profiling."""
        self.paths = paths
        self.num_timed_runs = num_timed_runs
        self.num_warmup_runs = num_warmup_runs
        self.imgsz = imgsz
        self.trt = trt  # run TensorRT profiling
        self.profile()  # run profiling

    def profile(self):
        """Export (where needed) and time every resolved model file, then print one table row per model."""
        files = self.get_files()

        if not files:
            print('No matching *.pt or *.onnx files found.')
            return

        table_rows = []
        device = 0 if torch.cuda.is_available() else 'cpu'
        for file in files:
            engine_file = ''  # stays empty when no TensorRT engine is exported; timing then reports zeros
            if file.suffix in ('.pt', '.yaml'):
                model = YOLO(str(file))
                num_params, num_flops = model.info()
                if self.trt and device == 0:
                    engine_file = model.export(format='engine', half=True, imgsz=self.imgsz, device=device)
                onnx_file = model.export(format='onnx', half=True, imgsz=self.imgsz, simplify=True, device=device)
            elif file.suffix == '.onnx':
                num_params, num_flops = self.get_onnx_model_info(file)
                onnx_file = file
            else:
                continue  # unsupported file type picked up by a glob pattern

            t_engine = self.profile_tensorrt_model(str(engine_file))
            t_onnx = self.profile_onnx_model(str(onnx_file))
            table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, num_params, num_flops))

        self.print_table(table_rows)

    def get_files(self):
        """
        Resolve `self.paths` into a sorted list of model files.

        Directories are searched for *.pt, *.onnx and *.yaml files, paths ending in .pt or .yaml are kept even if they
        do not exist on disk, and anything else is expanded as a glob pattern.

        Returns:
            (list[Path]): Sorted model file paths.
        """
        files = []
        for path in self.paths:
            path = Path(path)
            if path.is_dir():
                extensions = ['*.pt', '*.onnx', '*.yaml']
                files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
            elif path.suffix in {'.pt', '.yaml'}:  # add non-existing
                files.append(str(path))
            else:
                files.extend(glob.glob(str(path)))

        files = sorted(files)  # sort once and reuse for both the log line and the result
        print(f'Profiling: {files}')
        return [Path(file) for file in files]

    def get_onnx_model_info(self, onnx_file: str):
        """Return (num_params, num_flops) for an ONNX file; currently a placeholder that always returns zeros."""
        return 0.0, 0.0

    @staticmethod
    def _onnx_input_dtype(input_type: str):
        """Map an ONNX Runtime input type string such as 'tensor(float16)' to a numpy dtype, defaulting to float32."""
        return {
            'tensor(float16)': np.float16,
            'tensor(float)': np.float32,
            'tensor(double)': np.float64}.get(input_type, np.float32)

    def profile_tensorrt_model(self, engine_file: str):
        """
        Time inference of a TensorRT engine file.

        Args:
            engine_file (str): Path to the exported engine; a missing file yields zeros.

        Returns:
            (tuple): Mean and standard deviation of the per-run inference times.
        """
        if not Path(engine_file).is_file():
            return 0.0, 0.0

        # Warmup runs
        model = YOLO(engine_file)
        input_data = np.random.rand(self.imgsz, self.imgsz, 3).astype(np.float32)
        for _ in range(self.num_warmup_runs):
            model(input_data, verbose=False)

        # Timed runs
        run_times = []
        for _ in tqdm(range(self.num_timed_runs), desc=engine_file):
            results = model(input_data, verbose=False)
            # Inference time as reported by the model, used as-is (presumably already in ms - confirm)
            run_times.append(results[0].speed['inference'])

        return np.mean(run_times), np.std(run_times)

    def profile_onnx_model(self, onnx_file: str):
        """
        Time inference of an ONNX model with ONNX Runtime on the CPU execution provider.

        Args:
            onnx_file (str): Path to the ONNX model.

        Returns:
            (tuple): Mean and standard deviation of the run times in milliseconds.
        """
        check_requirements('onnxruntime')
        import onnxruntime as ort

        # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess = ort.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider'])

        input_tensor = sess.get_inputs()[0]
        # Use the dtype the model declares rather than guessing from CUDA availability, so that a user-supplied
        # FP32 model on a CUDA machine (or an FP16 model on a CPU-only machine) is fed matching input data
        input_dtype = self._onnx_input_dtype(input_tensor.type)
        input_data = np.random.rand(*input_tensor.shape).astype(input_dtype)
        input_name = input_tensor.name
        output_name = sess.get_outputs()[0].name

        # Warmup runs
        for _ in range(self.num_warmup_runs):
            sess.run([output_name], {input_name: input_data})

        # Timed runs
        run_times = []
        for _ in tqdm(range(self.num_timed_runs), desc=onnx_file):
            start_time = time.perf_counter()  # monotonic, high-resolution clock for interval timing
            sess.run([output_name], {input_name: input_data})
            run_times.append((time.perf_counter() - start_time) * 1000)  # Convert to milliseconds

        return np.mean(run_times), np.std(run_times)

    def generate_table_row(self, model_name, t_onnx, t_engine, num_params, num_flops):
        """Format one Markdown table row from (mean, std) timing pairs, the parameter count and the FLOPs value."""
        return f'| {model_name} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {num_params / 1e6:.1f} | {num_flops:.1f} |'

    def print_table(self, table_rows):
        """Print the Markdown table header, separator and the given rows."""
        gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU'
        header = f'| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |'
        separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|'

        print(header)
        print(separator)
        for row in table_rows:
            print(row)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Benchmark all export formats
|
||||
benchmark()
|
||||
|
||||
# Profiling models on ONNX and TensorRT
|
||||
ProfileModels(['yolov8n.yaml', 'yolov8s.yaml'])
|
||||
|
@ -67,7 +67,7 @@ def box_iou(box1, box2, eps=1e-7):
|
||||
|
||||
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
|
||||
(a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
|
||||
inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
|
||||
inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)
|
||||
|
||||
# IoU = inter / (area1 + area2 - inter)
|
||||
return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
|
||||
@ -104,8 +104,8 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
|
||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
||||
|
||||
# Intersection area
|
||||
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
|
||||
(b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)
|
||||
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * \
|
||||
(b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp_(0)
|
||||
|
||||
# Union Area
|
||||
union = w1 * h1 + w2 * h2 - inter + eps
|
||||
@ -143,7 +143,7 @@ def mask_iou(mask1, mask2, eps=1e-7):
|
||||
Returns:
|
||||
(torch.Tensor): A tensor of shape (N, M) representing masks IoU.
|
||||
"""
|
||||
intersection = torch.matmul(mask1, mask2.t()).clamp(0)
|
||||
intersection = torch.matmul(mask1, mask2.T).clamp_(0)
|
||||
union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection
|
||||
return intersection / (union + eps)
|
||||
|
||||
|
@ -10,6 +10,7 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from PIL import __version__ as pil_version
|
||||
from scipy.ndimage.filters import gaussian_filter1d
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, TryExcept, plt_settings, threaded
|
||||
|
||||
@ -455,7 +456,8 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
|
||||
for i, j in enumerate(index):
|
||||
y = data.values[:, j].astype('float')
|
||||
# y[y == 0] = np.nan # don't show zero values
|
||||
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
|
||||
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8) # actual results
|
||||
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2) # smoothing line
|
||||
ax[i].set_title(s[j], fontsize=12)
|
||||
# if j in [8, 9, 10]: # share train and val loss y axes
|
||||
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
|
||||
|
@ -41,7 +41,7 @@ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
|
||||
# (b, n_max_boxes, h*w) -> (b, h*w)
|
||||
fg_mask = mask_pos.sum(-2)
|
||||
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
|
||||
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1]) # (b, n_max_boxes, h*w)
|
||||
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
|
||||
max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
|
||||
|
||||
is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
|
||||
@ -132,7 +132,7 @@ class TaskAlignedAssigner(nn.Module):
|
||||
# Get anchor_align metric, (b, max_num_obj, h*w)
|
||||
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
|
||||
# Get topk_metric mask, (b, max_num_obj, h*w)
|
||||
mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.repeat([1, 1, self.topk]).bool())
|
||||
mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
|
||||
# Merge all mask to a final mask, (b, max_num_obj, h*w)
|
||||
mask_pos = mask_topk * mask_in_gts * mask_gt
|
||||
|
||||
@ -146,15 +146,15 @@ class TaskAlignedAssigner(nn.Module):
|
||||
bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
|
||||
|
||||
ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj
|
||||
ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes) # b, max_num_obj
|
||||
ind[1] = gt_labels.long().squeeze(-1) # b, max_num_obj
|
||||
ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes) # b, max_num_obj
|
||||
ind[1] = gt_labels.squeeze(-1) # b, max_num_obj
|
||||
# Get the scores of each grid for each gt cls
|
||||
bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w
|
||||
|
||||
# (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
|
||||
pd_boxes = pd_bboxes.unsqueeze(1).repeat(1, self.n_max_boxes, 1, 1)[mask_gt]
|
||||
gt_boxes = gt_bboxes.unsqueeze(2).repeat(1, 1, na, 1)[mask_gt]
|
||||
overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp(0)
|
||||
pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
|
||||
gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
|
||||
overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
|
||||
|
||||
align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
|
||||
return align_metric, overlaps
|
||||
@ -273,4 +273,4 @@ def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
||||
def bbox2dist(anchor_points, bbox, reg_max):
|
||||
"""Transform bbox(xyxy) to dist(ltrb)."""
|
||||
x1y1, x2y2 = bbox.chunk(2, -1)
|
||||
return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp(0, reg_max - 0.01) # dist (lt, rb)
|
||||
return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb)
|
||||
|
@ -205,6 +205,20 @@ def get_flops(model, imgsz=640):
|
||||
return 0
|
||||
|
||||
|
||||
def get_flops_with_torch_profiler(model, imgsz=640):
    """
    Compute model FLOPs with the PyTorch profiler (thop alternative).

    The model is profiled once on a small zero-filled square input whose side is twice the model's maximum stride
    (at least 64), and the measured FLOPs are scaled to the requested image size.

    Args:
        model (nn.Module): Model to profile; it is unwrapped with `de_parallel` first.
        imgsz (int | float | list | tuple): Target image size, either a single value or a (height, width) pair.

    Returns:
        (float): Estimated GFLOPs at `imgsz`.
    """
    model = de_parallel(model)
    p = next(model.parameters())
    stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2  # max stride
    # Channel count is read from dim 1 of the model's first parameter
    im = torch.zeros((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
    with torch.profiler.profile(with_flops=True) as prof:
        model(im)
    flops = sum(x.flops for x in prof.key_averages()) / 1E9
    # Accept tuples as well as lists; a tuple used to be wrapped as [tuple, tuple] and fail in the scaling below
    imgsz = list(imgsz) if isinstance(imgsz, (list, tuple)) else [imgsz, imgsz]  # expand if int/float
    return flops * imgsz[0] / stride * imgsz[1] / stride  # GFLOPs scaled to imgsz (e.g. 640x640)
|
||||
|
||||
|
||||
def initialize_weights(model):
|
||||
"""Initialize model weights to random values."""
|
||||
for m in model.modules():
|
||||
|
@ -60,8 +60,7 @@ class ClassificationValidator(BaseValidator):
|
||||
return self.metrics.results_dict
|
||||
|
||||
def build_dataset(self, img_path):
    """Build a ClassificationDataset rooted at `img_path` using the validator's args, with augmentation disabled."""
    options = dict(root=img_path, args=self.args, augment=False)
    return ClassificationDataset(**options)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
"""Builds and returns a data loader for classification tasks with given parameters."""
|
||||
|
Reference in New Issue
Block a user