ultralytics 8.0.107 (#2778)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Peter van Lunteren <contact@pvanlunteren.com>
Author: Glenn Jocher
Date: 2023-05-23 13:50:24 +02:00 (committed by GitHub)
parent 4db686a315
commit dada5b73c4
23 changed files with 236 additions and 73 deletions

ultralytics/yolo/data/augment.py

@@ -766,9 +766,17 @@ def v8_transforms(dataset, imgsz, hyp):
pre_transform=LetterBox(new_shape=(imgsz, imgsz)),
)])
flip_idx = dataset.data.get('flip_idx', None) # for keypoints augmentation
if dataset.use_keypoints and flip_idx is None and hyp.fliplr > 0.0:
hyp.fliplr = 0.0
LOGGER.warning("WARNING ⚠️ No `flip_idx` provided while training keypoints, setting augmentation 'fliplr=0.0'")
if dataset.use_keypoints:
kpt_shape = dataset.data.get('kpt_shape', None)
if flip_idx is None and hyp.fliplr > 0.0:
hyp.fliplr = 0.0
LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'")
elif flip_idx:
if len(flip_idx) != kpt_shape[0]:
raise ValueError(f'data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}')
elif flip_idx[0] != 0:
raise ValueError(f'data.yaml flip_idx={flip_idx} must be zero-index (start from 0)')
return Compose([
pre_transform,
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
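For reference, the new checks expect kpt_shape and flip_idx to be defined together in data.yaml. Below is a sketch of a parsed pose dataset dict that passes them, using the 17-keypoint left/right-swap convention from Ultralytics COCO-pose YAMLs (verify the values against your own data.yaml):

data = {
    'kpt_shape': [17, 3],  # 17 keypoints, each (x, y, visibility)
    'flip_idx': [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15],  # left/right pairs, zero-indexed
}
assert len(data['flip_idx']) == data['kpt_shape'][0]  # length must equal the keypoint count
assert data['flip_idx'][0] == 0  # must be zero-indexed (start from 0)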

ultralytics/yolo/data/utils.py

@@ -266,7 +266,7 @@ def check_det_dataset(dataset, autodownload=True):
return data # dictionary
def check_cls_dataset(dataset: str):
def check_cls_dataset(dataset: str, split=''):
"""
Check a classification dataset such as ImageNet.
@@ -275,6 +275,7 @@ def check_cls_dataset(dataset: str):
Args:
dataset (str): Name of the dataset.
split (str, optional): Dataset split, either 'val', 'test', or ''. Defaults to ''.
Returns:
data (dict): A dictionary containing the following keys and values:
@@ -298,10 +299,15 @@
train_set = data_dir / 'train'
val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test
if split == 'val' and not val_set:
LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.")
elif split == 'test' and not test_set:
LOGGER.info("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.")
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
names = dict(enumerate(sorted(names)))
return {'train': train_set, 'val': val_set, 'test': test_set, 'nc': nc, 'names': names}
return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}
class HUBDatasetStats():
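The net effect: requesting a missing classification split now falls back to its sibling instead of returning None. A minimal usage sketch (the dataset name and folder layout are illustrative):

from ultralytics.yolo.data.utils import check_cls_dataset

# Suppose the dataset folder has train/ and val/ but no test/ (illustrative)
data = check_cls_dataset('imagewoof', split='test')
# A warning is logged, and data['test'] now points at the val/ folder instead of None
print(data['test'] == data['val'])  # True when only one of val/test exists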

ultralytics/yolo/engine/validator.py

@@ -126,7 +126,7 @@ class BaseValidator:
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
self.data = check_det_dataset(self.args.data)
elif self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data)
self.data = check_cls_dataset(self.args.data, split=self.args.split)
else:
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

ultralytics/yolo/utils/benchmarks.py

@@ -3,7 +3,8 @@
Benchmark YOLO model formats for speed and accuracy
Usage:
from ultralytics.yolo.utils.benchmarks import run_benchmarks
from ultralytics.yolo.utils.benchmarks import ProfileModels, run_benchmarks
ProfileModels(['yolov8n.yaml', 'yolov8s.yaml'])
run_benchmarks(model='yolov8n.pt', imgsz=160)
Format | `format=argument` | Model
@@ -22,14 +23,19 @@ TensorFlow.js | `tfjs` | yolov8n_web_model/
PaddlePaddle | `paddle` | yolov8n_paddle_model/
"""
import glob
import platform
import time
from pathlib import Path
import numpy as np
import torch.cuda
from tqdm import tqdm
from ultralytics import YOLO
from ultralytics.yolo.engine.exporter import export_formats
from ultralytics.yolo.utils import LINUX, LOGGER, MACOS, ROOT, SETTINGS
from ultralytics.yolo.utils.checks import check_yolo
from ultralytics.yolo.utils.checks import check_requirements, check_yolo
from ultralytics.yolo.utils.downloads import download
from ultralytics.yolo.utils.files import file_size
from ultralytics.yolo.utils.torch_utils import select_device
@@ -140,5 +146,140 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
return df
class ProfileModels:
"""
ProfileModels class for profiling different models on ONNX and TensorRT.
This class profiles the performance of different models given their paths. Profiling results include model speed and FLOPs.
Attributes:
paths (list): Paths of the models to profile.
num_timed_runs (int): Number of timed runs for the profiling. Default is 100.
num_warmup_runs (int): Number of warmup runs before profiling. Default is 3.
imgsz (int): Image size used in the models. Default is 640.
Methods:
profile(): Profiles the models and prints the result.
"""
def __init__(self, paths: list, num_timed_runs=100, num_warmup_runs=3, imgsz=640, trt=True):
self.paths = paths
self.num_timed_runs = num_timed_runs
self.num_warmup_runs = num_warmup_runs
self.imgsz = imgsz
self.trt = trt # run TensorRT profiling
self.profile() # run profiling
def profile(self):
files = self.get_files()
if not files:
print('No matching *.pt, *.onnx, or *.yaml files found.')
return
table_rows = []
device = 0 if torch.cuda.is_available() else 'cpu'
for file in files:
engine_file = ''
if file.suffix in ('.pt', '.yaml'):
model = YOLO(str(file))
num_params, num_flops = model.info()
if self.trt and device == 0:
engine_file = model.export(format='engine', half=True, imgsz=self.imgsz, device=device)
onnx_file = model.export(format='onnx', half=True, imgsz=self.imgsz, simplify=True, device=device)
elif file.suffix == '.onnx':
num_params, num_flops = self.get_onnx_model_info(file)
onnx_file = file
else:
continue
t_engine = self.profile_tensorrt_model(str(engine_file))
t_onnx = self.profile_onnx_model(str(onnx_file))
table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, num_params, num_flops))
self.print_table(table_rows)
def get_files(self):
files = []
for path in self.paths:
path = Path(path)
if path.is_dir():
extensions = ['*.pt', '*.onnx', '*.yaml']
files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
elif path.suffix in {'.pt', '.yaml'}: # add non-existing
files.append(str(path))
else:
files.extend(glob.glob(str(path)))
print(f'Profiling: {sorted(files)}')
return [Path(file) for file in sorted(files)]
def get_onnx_model_info(self, onnx_file: str):
return 0.0, 0.0  # placeholder: ONNX parameter/FLOP introspection is not implemented in this version
def profile_tensorrt_model(self, engine_file: str):
if not Path(engine_file).is_file():
return 0.0, 0.0
# Warmup runs
model = YOLO(engine_file)
input_data = np.random.rand(self.imgsz, self.imgsz, 3).astype(np.float32)
for _ in range(self.num_warmup_runs):
model(input_data, verbose=False)
# Timed runs
run_times = []
for _ in tqdm(range(self.num_timed_runs), desc=engine_file):
results = model(input_data, verbose=False)
run_times.append(results[0].speed['inference'])  # inference time is already reported in milliseconds
return np.mean(run_times), np.std(run_times)
def profile_onnx_model(self, onnx_file: str):
check_requirements('onnxruntime')
import onnxruntime as ort
# Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess = ort.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider'])
input_tensor = sess.get_inputs()[0]
input_data = np.random.rand(*input_tensor.shape).astype(np.float16 if torch.cuda.is_available() else np.float32)
input_name = input_tensor.name
output_name = sess.get_outputs()[0].name
# Warmup runs
for _ in range(self.num_warmup_runs):
sess.run([output_name], {input_name: input_data})
# Timed runs
run_times = []
for _ in tqdm(range(self.num_timed_runs), desc=onnx_file):
start_time = time.time()
sess.run([output_name], {input_name: input_data})
run_times.append((time.time() - start_time) * 1000) # Convert to milliseconds
return np.mean(run_times), np.std(run_times)
def generate_table_row(self, model_name, t_onnx, t_engine, num_params, num_flops):
return f'| {model_name} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {num_params / 1e6:.1f} | {num_flops:.1f} |'
def print_table(self, table_rows):
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU'
header = f'| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |'
separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|'
print(header)
print(separator)
for row in table_rows:
print(row)
if __name__ == '__main__':
# Benchmark all export formats
benchmark()
# Profiling models on ONNX and TensorRT
ProfileModels(['yolov8n.yaml', 'yolov8s.yaml'])
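ProfileModels also accepts directories: get_files() globs *.pt, *.onnx and *.yaml beneath each path. A usage sketch (the directory path is illustrative; trt=False skips TensorRT export on CPU-only machines):

from ultralytics.yolo.utils.benchmarks import ProfileModels

ProfileModels(['path/to/models/'], num_timed_runs=50, imgsz=640, trt=False)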

ultralytics/yolo/utils/metrics.py

@@ -67,7 +67,7 @@ def box_iou(box1, box2, eps=1e-7):
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
(a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)
# IoU = inter / (area1 + area2 - inter)
return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
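Switching clamp(0) to the in-place clamp_(0) avoids allocating an intermediate tensor. A quick numeric check of the IoU formula above, mirroring the same xyxy convention:

import torch

b1 = torch.tensor([[0., 0., 10., 10.]])  # one box, xyxy
b2 = torch.tensor([[5., 5., 15., 15.]])
(a1, a2), (c1, c2) = b1.unsqueeze(1).chunk(2, 2), b2.unsqueeze(0).chunk(2, 2)
inter = (torch.min(a2, c2) - torch.max(a1, c1)).clamp_(0).prod(2)  # 5 * 5 = 25
print(inter / ((a2 - a1).prod(2) + (c2 - c1).prod(2) - inter + 1e-7))  # 25 / 175 ≈ 0.1429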
@@ -104,8 +104,8 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
# Intersection area
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
(b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * \
(b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp_(0)
# Union Area
union = w1 * h1 + w2 * h2 - inter + eps
@@ -143,7 +143,7 @@ def mask_iou(mask1, mask2, eps=1e-7):
Returns:
(torch.Tensor): A tensor of shape (N, M) representing masks IoU.
"""
intersection = torch.matmul(mask1, mask2.t()).clamp(0)
intersection = torch.matmul(mask1, mask2.T).clamp_(0)
union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection
return intersection / (union + eps)
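The mask variant follows the same pattern: each row is a flattened binary mask, the matmul counts overlapping pixels, and clamp_ again works in place. A tiny worked example:

import torch

m1 = torch.tensor([[1., 1., 0., 0.]])  # N=1 mask over 4 flattened pixels
m2 = torch.tensor([[0., 1., 1., 0.]])  # M=1
inter = torch.matmul(m1, m2.T).clamp_(0)              # 1 shared pixel
union = m1.sum(1)[:, None] + m2.sum(1)[None] - inter  # 2 + 2 - 1 = 3
print(inter / (union + 1e-7))                         # ≈ 0.3333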

ultralytics/yolo/utils/plotting.py

@@ -10,6 +10,7 @@ import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from PIL import __version__ as pil_version
from scipy.ndimage.filters import gaussian_filter1d
from ultralytics.yolo.utils import LOGGER, TryExcept, plt_settings, threaded
@@ -455,7 +456,8 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
for i, j in enumerate(index):
y = data.values[:, j].astype('float')
# y[y == 0] = np.nan # don't show zero values
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8) # actual results
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2) # smoothing line
ax[i].set_title(s[j], fontsize=12)
# if j in [8, 9, 10]: # share train and val loss y axes
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
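The new overlay is a plain 1-D Gaussian blur of each metric column. Note that scipy.ndimage.filters is a deprecated alias; current SciPy exposes the same function at scipy.ndimage. A minimal sketch of what the dotted 'smooth' line plots:

import numpy as np
from scipy.ndimage import gaussian_filter1d  # modern path for the deprecated scipy.ndimage.filters import

y = np.array([0.2, 0.5, 0.3, 0.8, 0.6, 0.9])  # e.g. one metric column from results.csv
print(gaussian_filter1d(y, sigma=3))  # smoothed values drawn as the ':' line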

ultralytics/yolo/utils/tal.py

@@ -41,7 +41,7 @@ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
# (b, n_max_boxes, h*w) -> (b, h*w)
fg_mask = mask_pos.sum(-2)
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1]) # (b, n_max_boxes, h*w)
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
@@ -132,7 +132,7 @@ class TaskAlignedAssigner(nn.Module):
# Get anchor_align metric, (b, max_num_obj, h*w)
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
# Get topk_metric mask, (b, max_num_obj, h*w)
mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.repeat([1, 1, self.topk]).bool())
mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
# Merge all mask to a final mask, (b, max_num_obj, h*w)
mask_pos = mask_topk * mask_in_gts * mask_gt
@@ -146,15 +146,15 @@ class TaskAlignedAssigner(nn.Module):
bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj
ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes) # b, max_num_obj
ind[1] = gt_labels.long().squeeze(-1) # b, max_num_obj
ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes) # b, max_num_obj
ind[1] = gt_labels.squeeze(-1) # b, max_num_obj
# Get the scores of each grid for each gt cls
bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w
# (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
pd_boxes = pd_bboxes.unsqueeze(1).repeat(1, self.n_max_boxes, 1, 1)[mask_gt]
gt_boxes = gt_bboxes.unsqueeze(2).repeat(1, 1, na, 1)[mask_gt]
overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp(0)
pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
return align_metric, overlaps
@@ -273,4 +273,4 @@ def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
def bbox2dist(anchor_points, bbox, reg_max):
"""Transform bbox(xyxy) to dist(ltrb)."""
x1y1, x2y2 = bbox.chunk(2, -1)
return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp(0, reg_max - 0.01) # dist (lt, rb)
return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb)
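Most of this file's changes swap Tensor.repeat for Tensor.expand: repeat materializes copies, while expand returns a broadcasted view with no new allocation, which is safe here because the expanded tensors are only read or immediately indexed. A quick demonstration:

import torch

x = torch.arange(3).view(-1, 1)  # shape (3, 1)
a = x.repeat(1, 4)               # allocates a new (3, 4) tensor
b = x.expand(-1, 4)              # (3, 4) view sharing x's storage, no copy
assert torch.equal(a, b)
print(x.data_ptr() == b.data_ptr())  # True: expand did not allocate new memory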

ultralytics/yolo/utils/torch_utils.py

@@ -205,6 +205,20 @@ def get_flops(model, imgsz=640):
return 0
def get_flops_with_torch_profiler(model, imgsz=640):
# Compute model FLOPs (thop alternative)
model = de_parallel(model)
p = next(model.parameters())
stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride
im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
with torch.profiler.profile(with_flops=True) as prof:
model(im)
flops = sum(x.flops for x in prof.key_averages()) / 1E9
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs
return flops
def initialize_weights(model):
"""Initialize model weights to random values."""
for m in model.modules():
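Usage sketch for the new thop-free FLOPs helper added above (the checkpoint name is the standard YOLOv8n release file; the printed value is approximate):

from ultralytics import YOLO
from ultralytics.yolo.utils.torch_utils import get_flops_with_torch_profiler

model = YOLO('yolov8n.pt').model  # underlying nn.Module
print(f'{get_flops_with_torch_profiler(model, imgsz=640):.1f} GFLOPs')  # ~8.7 for YOLOv8n at 640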

ultralytics/yolo/v8/classify/val.py

@@ -60,8 +60,7 @@ class ClassificationValidator(BaseValidator):
return self.metrics.results_dict
def build_dataset(self, img_path):
dataset = ClassificationDataset(root=img_path, args=self.args, augment=False)
return dataset
return ClassificationDataset(root=img_path, args=self.args, augment=False)
def get_dataloader(self, dataset_path, batch_size):
"""Builds and returns a data loader for classification tasks with given parameters."""