ultralytics 8.0.108
add Meituan YOLOv6 models (#2811)
Co-authored-by: Michael Currie <mcurrie@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hicham Talaoubrid <98521878+HichTala@users.noreply.github.com>
Co-authored-by: Zlobin Vladimir <vladimir.zlobin@intel.com>
Co-authored-by: Szymon Mikler <sjmikler@gmail.com>
@@ -111,10 +111,13 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
     check_cfg_mismatch(cfg, overrides)
     cfg = {**cfg, **overrides}  # merge cfg and overrides dicts (prefer overrides)

-    # Special handling for numeric project/names
+    # Special handling for numeric project/name
     for k in 'project', 'name':
         if k in cfg and isinstance(cfg[k], (int, float)):
             cfg[k] = str(cfg[k])
+    if cfg.get('name') == 'model':  # assign model to 'name' arg
+        cfg['name'] = cfg.get('model', '').split('.')[0]
+        LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")

     # Type and Value checks
     for k, v in cfg.items():
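A minimal standalone sketch of the sanitizing logic above (helper name and values are illustrative, not part of the commit):

    from types import SimpleNamespace

    def merge_cfg(cfg: dict, overrides: dict) -> SimpleNamespace:
        cfg = {**cfg, **overrides}  # overrides win on key collisions
        for k in ('project', 'name'):
            if k in cfg and isinstance(cfg[k], (int, float)):
                cfg[k] = str(cfg[k])  # a numeric name/project would break path joins
        if cfg.get('name') == 'model':
            cfg['name'] = cfg.get('model', '').split('.')[0]  # 'yolov8n.pt' -> 'yolov8n'
        return SimpleNamespace(**cfg)

    args = merge_cfg({'model': 'yolov8n.pt', 'name': 'model'}, {'project': 2023})
    assert args.name == 'yolov8n' and args.project == '2023'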
@@ -116,7 +116,7 @@ def check_source(source):
         is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
         is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
         webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
-        screenshot = source.lower().startswith('screen')
+        screenshot = source.lower() == 'screen'
         if is_url and is_file:
             source = check_file(source)  # download
     elif isinstance(source, tuple(LOADERS)):
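The screenshot check is now an exact match, so only the literal source string 'screen' triggers screen capture; a file such as 'screenshot.png' no longer matches. A small illustration (abridged format list, illustrative helper):

    from pathlib import Path

    IMG_FORMATS = ('bmp', 'jpeg', 'jpg', 'png')  # abridged for illustration

    def classify(source: str):
        is_file = Path(source).suffix[1:] in IMG_FORMATS
        screenshot = source.lower() == 'screen'  # exact match, not startswith
        return is_file, screenshot

    assert classify('screen') == (False, True)
    assert classify('screenshot.png') == (True, False)  # previously routed to screen capture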
@@ -331,12 +331,12 @@ class YOLO:
         overrides = self.overrides.copy()
         overrides.update(kwargs)
         overrides['mode'] = 'export'
+        if overrides.get('imgsz') is None:
+            overrides['imgsz'] = self.model.args['imgsz']  # use trained imgsz unless custom value is passed
+        if overrides.get('batch') is None:
+            overrides['batch'] = 1  # default to 1 if not modified
         args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
         args.task = self.task
-        if args.imgsz == DEFAULT_CFG.imgsz:
-            args.imgsz = self.model.args['imgsz']  # use trained imgsz unless custom value is passed
-        if args.batch == DEFAULT_CFG.batch:
-            args.batch = 1  # default to 1 if not modified
         return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)

     def train(self, **kwargs):
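In user terms (a sketch, assuming a local yolov8n.pt trained at the default 640):

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')              # trained at imgsz=640
    model.export(format='onnx')             # now inherits imgsz=640 and batch=1 via overrides
    model.export(format='onnx', imgsz=320)  # an explicit value still wins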
@@ -684,12 +684,17 @@ def check_amp(model):
     im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
     prefix = colorstr('AMP: ')
     LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
+    warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
     try:
         from ultralytics import YOLO
         assert amp_allclose(YOLO('yolov8n.pt'), im)
         LOGGER.info(f'{prefix}checks passed ✅')
     except ConnectionError:
-        LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. Setting 'amp=True'.")
+        LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
+    except (AttributeError, ModuleNotFoundError):
+        LOGGER.warning(
+            f'{prefix}checks skipped ⚠️. Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}'
+        )
     except AssertionError:
         LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
                        f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
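The amp_allclose helper asserted above is defined outside this hunk; the idea it implements is roughly the following (a hypothetical sketch assuming a CUDA device and a tensor-returning model, not the shipped implementation):

    import torch

    def amp_allclose_sketch(model: torch.nn.Module, im: torch.Tensor) -> bool:
        # FP32 and AMP forward passes should agree within a loose tolerance;
        # a large mismatch suggests unstable AMP kernels on this system.
        with torch.no_grad():
            a = model(im).float()            # FP32 baseline
            with torch.cuda.amp.autocast(True):
                b = model(im).float()        # AMP run
        return torch.allclose(a, b, atol=0.5)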
@@ -372,12 +372,15 @@ def is_online() -> bool:
     """
     import socket

-    for server in '1.1.1.1', '8.8.8.8', '223.5.5.5':  # Cloudflare, Google, AliDNS:
+    for host in '1.1.1.1', '8.8.8.8', '223.5.5.5':  # Cloudflare, Google, AliDNS:
         try:
-            socket.create_connection((server, 53), timeout=2)  # connect to (server, port=53)
-            return True
+            test_connection = socket.create_connection(address=(host, 53), timeout=2)
         except (socket.timeout, socket.gaierror, OSError):
             continue
+        else:
+            # If the connection was successful, close it to avoid a ResourceWarning
+            test_connection.close()
+            return True
     return False
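An equivalent formulation of the new logic (a sketch, not the shipped code): the socket's context manager performs the close that the else branch handles above.

    import socket
    from contextlib import suppress

    def is_online(timeout: float = 2.0) -> bool:
        for host in ('1.1.1.1', '8.8.8.8', '223.5.5.5'):  # Cloudflare, Google, AliDNS
            with suppress(OSError):  # socket.timeout and gaierror are OSError subclasses
                with socket.create_connection((host, 53), timeout=timeout):
                    return True  # connection opened; closed automatically on exit
        return False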
@@ -3,7 +3,7 @@
 Benchmark a YOLO model formats for speed and accuracy

 Usage:
-    from ultralytics.yolo.utils.benchmarks import ProfileModels, run_benchmarks
+    from ultralytics.yolo.utils.benchmarks import ProfileModels, benchmark
     ProfileModels(['yolov8n.yaml', 'yolov8s.yaml'])
     run_benchmarks(model='yolov8n.pt', imgsz=160)
@@ -163,7 +163,7 @@ class ProfileModels:
         profile(): Profiles the models and prints the result.
     """

-    def __init__(self, paths: list, num_timed_runs=100, num_warmup_runs=3, imgsz=640, trt=True):
+    def __init__(self, paths: list, num_timed_runs=100, num_warmup_runs=10, imgsz=640, trt=True):
         self.paths = paths
         self.num_timed_runs = num_timed_runs
         self.num_warmup_runs = num_warmup_runs
@@ -181,22 +181,22 @@ class ProfileModels:
         table_rows = []
         device = 0 if torch.cuda.is_available() else 'cpu'
         for file in files:
-            engine_file = ''
+            engine_file = file.with_suffix('.engine')
             if file.suffix in ('.pt', '.yaml'):
                 model = YOLO(str(file))
-                num_params, num_flops = model.info()
-                if self.trt and device == 0:
+                model_info = model.info()
+                if self.trt and device == 0 and not engine_file.is_file():
                     engine_file = model.export(format='engine', half=True, imgsz=self.imgsz, device=device)
                 onnx_file = model.export(format='onnx', half=True, imgsz=self.imgsz, simplify=True, device=device)
             elif file.suffix == '.onnx':
-                num_params, num_flops = self.get_onnx_model_info(file)
+                model_info = self.get_onnx_model_info(file)
                 onnx_file = file
             else:
                 continue

             t_engine = self.profile_tensorrt_model(str(engine_file))
             t_onnx = self.profile_onnx_model(str(onnx_file))
-            table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, num_params, num_flops))
+            table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info))

         self.print_table(table_rows)
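Typical invocation of this profiling loop, mirroring the module docstring (assumes the listed models resolve locally or are downloadable):

    from ultralytics.yolo.utils.benchmarks import ProfileModels

    # Exports ONNX (and TensorRT engines on GPU machines) for each model,
    # times both backends, and prints one markdown table row per model.
    ProfileModels(['yolov8n.yaml', 'yolov8s.yaml'], imgsz=640).profile()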
@@ -216,10 +216,21 @@ class ProfileModels:
         return [Path(file) for file in sorted(files)]

     def get_onnx_model_info(self, onnx_file: str):
-        return 0.0, 0.0
+        # return (num_layers, num_params, num_gradients, num_flops)
+        return 0.0, 0.0, 0.0, 0.0
+
+    def iterative_sigma_clipping(self, data, sigma=2, max_iters=5):
+        data = np.array(data)
+        for _ in range(max_iters):
+            mean, std = np.mean(data), np.std(data)
+            clipped_data = data[(data > mean - sigma * std) & (data < mean + sigma * std)]
+            if len(clipped_data) == len(data):
+                break
+            data = clipped_data
+        return data

     def profile_tensorrt_model(self, engine_file: str):
-        if not Path(engine_file).is_file():
+        if not self.trt or not Path(engine_file).is_file():
             return 0.0, 0.0

         # Warmup runs
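A standalone demonstration of the sigma clipping added above: timing samples outside mean ± 2*std are discarded iteratively, which strips warm-up stragglers from the reported latency (sample numbers are made up):

    import numpy as np

    def iterative_sigma_clipping(data, sigma=2, max_iters=5):
        data = np.array(data)
        for _ in range(max_iters):
            mean, std = np.mean(data), np.std(data)
            clipped = data[(data > mean - sigma * std) & (data < mean + sigma * std)]
            if len(clipped) == len(data):  # converged, nothing left to clip
                break
            data = clipped
        return data

    times = [10.2, 9.8, 10.0, 10.1, 9.9, 10.3, 9.7, 10.0, 10.0, 50.0]  # one 50 ms straggler
    print(iterative_sigma_clipping(times))  # the 50.0 outlier is removed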
@@ -230,10 +241,11 @@ class ProfileModels:

         # Timed runs
         run_times = []
-        for _ in tqdm(range(self.num_timed_runs), desc=engine_file):
+        for _ in tqdm(range(self.num_timed_runs * 30), desc=engine_file):
             results = model(input_data, verbose=False)
             run_times.append(results[0].speed['inference'])  # Convert to milliseconds

+        run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3)  # sigma clipping
         return np.mean(run_times), np.std(run_times)

     def profile_onnx_model(self, onnx_file: str):
@@ -246,7 +258,23 @@ class ProfileModels:
         sess = ort.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider'])

         input_tensor = sess.get_inputs()[0]
-        input_data = np.random.rand(*input_tensor.shape).astype(np.float16 if torch.cuda.is_available() else np.float32)
+        input_type = input_tensor.type
+
+        # Mapping ONNX datatype to numpy datatype
+        if 'float16' in input_type:
+            input_dtype = np.float16
+        elif 'float' in input_type:
+            input_dtype = np.float32
+        elif 'double' in input_type:
+            input_dtype = np.float64
+        elif 'int64' in input_type:
+            input_dtype = np.int64
+        elif 'int32' in input_type:
+            input_dtype = np.int32
+        else:
+            raise ValueError(f'Unsupported ONNX datatype {input_type}')
+
+        input_data = np.random.rand(*input_tensor.shape).astype(input_dtype)
         input_name = input_tensor.name
         output_name = sess.get_outputs()[0].name
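The same dtype mapping as a standalone sketch (assumes a local model.onnx with a static input shape; the dict-based lookup is an illustrative alternative to the if/elif chain):

    import numpy as np
    import onnxruntime as ort

    # Order matters: 'float16' must be tested before 'float', which it contains.
    ONNX_TO_NUMPY = {'float16': np.float16, 'float': np.float32, 'double': np.float64,
                     'int64': np.int64, 'int32': np.int32}

    sess = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])
    inp = sess.get_inputs()[0]  # inp.type looks like 'tensor(float)'
    dtype = next((np_t for key, np_t in ONNX_TO_NUMPY.items() if key in inp.type), None)
    if dtype is None:
        raise ValueError(f'Unsupported ONNX datatype {inp.type}')
    x = np.random.rand(*inp.shape).astype(dtype)  # requires static (non-symbolic) dims
    out = sess.run([sess.get_outputs()[0].name], {inp.name: x})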
@@ -261,17 +289,19 @@ class ProfileModels:
             sess.run([output_name], {input_name: input_data})
             run_times.append((time.time() - start_time) * 1000)  # Convert to milliseconds

+        run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3)  # sigma clipping
         return np.mean(run_times), np.std(run_times)

-    def generate_table_row(self, model_name, t_onnx, t_engine, num_params, num_flops):
-        return f'| {model_name} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {num_params / 1e6:.1f} | {num_flops:.1f} |'
+    def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
+        layers, params, gradients, flops = model_info
+        return f'| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |'

     def print_table(self, table_rows):
         gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU'
         header = f'| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |'
         separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|'

-        print(header)
+        print(f'\n\n{header}')
         print(separator)
         for row in table_rows:
             print(row)
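With model_info now a (layers, params, gradients, flops) tuple, a row renders as below; every number here is a made-up placeholder:

    model_info = (225, 3151904, 3151888, 8.7)       # hypothetical layers/params/gradients/GFLOPs
    t_onnx, t_engine = (82.31, 1.52), (4.60, 0.18)  # hypothetical (mean_ms, std_ms) pairs
    layers, params, gradients, flops = model_info
    print(f'| {"yolov8n":18s} | 640 | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | '
          f'{t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |')
    # | yolov8n            | 640 | - | 82.31 ± 1.52 ms | 4.60 ± 0.18 ms | 3.2 | 8.7 |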
@@ -104,7 +104,8 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
     """
     if ratio_pad is None:  # calculate from img0_shape
         gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
-        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
+        pad = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), round(
+            (img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)  # wh padding
     else:
         gain = ratio_pad[0][0]
         pad = ratio_pad[1]
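A worked example of the padding math: letterboxing a 1280x720 frame into 640x640 gives gain 0.5, zero horizontal padding and 140 px of vertical padding; round(x - 0.1) keeps the pad integral while biasing exact half-pixel ties downward (plain round() would bankers-round them):

    img1_shape, img0_shape = (640, 640), (720, 1280)  # (height, width)
    gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # min(0.889, 0.5)
    pad = (round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),  # (640 - 640) / 2 -> 0
           round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1))  # (640 - 360) / 2 -> 140
    assert gain == 0.5 and pad == (0, 140)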
@@ -162,8 +162,9 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
     """Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]."""
     if not verbose:
         return
-    n_p = get_num_params(model)
-    n_g = get_num_gradients(model)  # number gradients
+    n_p = get_num_params(model)  # number of parameters
+    n_g = get_num_gradients(model)  # number of gradients
+    n_l = len(list(model.modules()))  # number of layers
     if detailed:
         LOGGER.info(
             f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
@@ -173,11 +174,12 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
                 (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype))

     flops = get_flops(model, imgsz)
-    fused = ' (fused)' if model.is_fused() else ''
+    fused = ' (fused)' if getattr(model, 'is_fused', lambda: False)() else ''
     fs = f', {flops:.1f} GFLOPs' if flops else ''
-    m = Path(getattr(model, 'yaml_file', '') or model.yaml.get('yaml_file', '')).stem.replace('yolo', 'YOLO') or 'Model'
-    LOGGER.info(f'{m} summary{fused}: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}')
-    return n_p, flops
+    yaml_file = getattr(model, 'yaml_file', '') or getattr(model, 'yaml', {}).get('yaml_file', '')
+    model_name = Path(yaml_file).stem.replace('yolo', 'YOLO') or 'Model'
+    LOGGER.info(f'{model_name} summary{fused}: {n_l} layers, {n_p} parameters, {n_g} gradients{fs}')
+    return n_l, n_p, n_g, flops


 def get_num_params(model):
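Since model_info() now returns four values, any caller unpacking two must be updated; the same tuple surfaces through model.info(), which is what the profile() hunk above relies on (sketch, assumes a local yolov8n.pt):

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')
    n_l, n_p, n_g, flops = model.info()  # four values: layers, params, gradients, GFLOPs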
@@ -199,8 +201,7 @@ def get_flops(model, imgsz=640):
         im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
         flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 if thop else 0  # stride GFLOPs
         imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
-        flops = flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
-        return flops
+        return flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
     except Exception:
         return 0
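The scaling rationale: thop profiles a single stride-sized input, and convolutional FLOPs grow with input area, so the stride measurement is scaled by (imgsz/stride) in each dimension. Worked numbers (the stride GFLOPs value is illustrative):

    stride, imgsz = 32, [640, 640]
    stride_gflops = 0.0219                                  # hypothetical GFLOPs at 32x32
    gflops = stride_gflops * imgsz[0] / stride * imgsz[1] / stride
    print(f'{gflops:.2f} GFLOPs at {imgsz[0]}x{imgsz[1]}')  # 0.0219 * 20 * 20 = 8.76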