diff --git a/docker/Dockerfile-arm64 b/docker/Dockerfile-arm64 index c11dba4..bd54323 100644 --- a/docker/Dockerfile-arm64 +++ b/docker/Dockerfile-arm64 @@ -24,7 +24,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u # Install pip packages RUN python3 -m pip install --upgrade pip wheel -RUN pip install --no-cache -e . +RUN pip install --no-cache -e . thop # Usage Examples ------------------------------------------------------------------------------------------------------- diff --git a/docker/Dockerfile-cpu b/docker/Dockerfile-cpu index fa6ec12..c58e423 100644 --- a/docker/Dockerfile-cpu +++ b/docker/Dockerfile-cpu @@ -25,7 +25,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u # Install pip packages RUN python3 -m pip install --upgrade pip wheel -RUN pip install --no-cache -e . --extra-index-url https://download.pytorch.org/whl/cpu +RUN pip install --no-cache -e . thop --extra-index-url https://download.pytorch.org/whl/cpu # Usage Examples ------------------------------------------------------------------------------------------------------- diff --git a/docker/Dockerfile-jetson b/docker/Dockerfile-jetson index fc4971c..6fbbd5d 100644 --- a/docker/Dockerfile-jetson +++ b/docker/Dockerfile-jetson @@ -25,7 +25,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u # Install pip packages manually for TensorRT compatibility https://github.com/NVIDIA/TensorRT/issues/2567 RUN python3 -m pip install --upgrade pip wheel -RUN pip install --no-cache tqdm matplotlib pyyaml psutil pandas onnx "numpy==1.23" +RUN pip install --no-cache tqdm matplotlib pyyaml psutil pandas onnx thop "numpy==1.23" RUN pip install --no-cache -e . # Set environment variables diff --git a/docs/modes/train.md b/docs/modes/train.md index 1d629a9..d560360 100644 --- a/docs/modes/train.md +++ b/docs/modes/train.md @@ -83,6 +83,7 @@ task. | `resume` | `False` | resume training from last checkpoint | | `amp` | `True` | Automatic Mixed Precision (AMP) training, choices=[True, False] | | `fraction` | `1.0` | dataset fraction to train on (default is 1.0, all images in train set) | +| `profile` | `False` | profile ONNX and TensorRT speeds during training for loggers | | `lr0` | `0.01` | initial learning rate (i.e. SGD=1E-2, Adam=1E-3) | | `lrf` | `0.01` | final learning rate (lr0 * lrf) | | `momentum` | `0.937` | SGD momentum/Adam beta1 | diff --git a/docs/usage/cfg.md b/docs/usage/cfg.md index ae3853c..0c99002 100644 --- a/docs/usage/cfg.md +++ b/docs/usage/cfg.md @@ -105,6 +105,7 @@ The training settings for YOLO models encompass various hyperparameters and conf | `resume` | `False` | resume training from last checkpoint | | `amp` | `True` | Automatic Mixed Precision (AMP) training, choices=[True, False] | | `fraction` | `1.0` | dataset fraction to train on (default is 1.0, all images in train set) | +| `profile` | `False` | profile ONNX and TensorRT speeds during training for loggers | | `lr0` | `0.01` | initial learning rate (i.e. SGD=1E-2, Adam=1E-3) | | `lrf` | `0.01` | final learning rate (lr0 * lrf) | | `momentum` | `0.937` | SGD momentum/Adam beta1 | diff --git a/ultralytics/yolo/cfg/__init__.py b/ultralytics/yolo/cfg/__init__.py index c3c8d60..bc74d55 100644 --- a/ultralytics/yolo/cfg/__init__.py +++ b/ultralytics/yolo/cfg/__init__.py @@ -1,4 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license + import contextlib import re import shutil @@ -72,7 +73,7 @@ CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic' CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'rect', 'cos_lr', 'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf', 'save_crop', 'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras', - 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader') + 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader', 'profile') def cfg2dict(cfg): diff --git a/ultralytics/yolo/cfg/default.yaml b/ultralytics/yolo/cfg/default.yaml index abf12c3..41b9449 100644 --- a/ultralytics/yolo/cfg/default.yaml +++ b/ultralytics/yolo/cfg/default.yaml @@ -31,6 +31,7 @@ close_mosaic: 0 # (int) disable mosaic augmentation for final epochs resume: False # resume training from last checkpoint amp: True # Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check fraction: 1.0 # dataset fraction to train on (default is 1.0, all images in train set) +profile: False # profile ONNX and TensorRT speeds during training for loggers # Segmentation overlap_mask: True # masks should overlap during training (segment train only) mask_ratio: 4 # mask downsample ratio (segment train only) diff --git a/ultralytics/yolo/utils/benchmarks.py b/ultralytics/yolo/utils/benchmarks.py index ff92683..b87fd2b 100644 --- a/ultralytics/yolo/utils/benchmarks.py +++ b/ultralytics/yolo/utils/benchmarks.py @@ -4,7 +4,7 @@ Benchmark a YOLO model formats for speed and accuracy Usage: from ultralytics.yolo.utils.benchmarks import ProfileModels, benchmark - ProfileModels(['yolov8n.yaml', 'yolov8s.yaml']) + ProfileModels(['yolov8n.yaml', 'yolov8s.yaml']).profile() run_benchmarks(model='yolov8n.pt', imgsz=160) Format | `format=argument` | Model @@ -163,13 +163,13 @@ class ProfileModels: profile(): Profiles the models and prints the result. """ - def __init__(self, paths: list, num_timed_runs=100, num_warmup_runs=10, imgsz=640, trt=True): + def __init__(self, paths: list, num_timed_runs=100, num_warmup_runs=10, imgsz=640, trt=True, device=None): self.paths = paths self.num_timed_runs = num_timed_runs self.num_warmup_runs = num_warmup_runs self.imgsz = imgsz self.trt = trt # run TensorRT profiling - self.profile() # run profiling + self.device = device or torch.device(0 if torch.cuda.is_available() else 'cpu') def profile(self): files = self.get_files() @@ -179,15 +179,16 @@ class ProfileModels: return table_rows = [] - device = 0 if torch.cuda.is_available() else 'cpu' + output = [] for file in files: engine_file = file.with_suffix('.engine') if file.suffix in ('.pt', '.yaml'): model = YOLO(str(file)) + model.fuse() # to report correct params and GFLOPs in model.info() model_info = model.info() - if self.trt and device == 0 and not engine_file.is_file(): - engine_file = model.export(format='engine', half=True, imgsz=self.imgsz, device=device) - onnx_file = model.export(format='onnx', half=True, imgsz=self.imgsz, simplify=True, device=device) + if self.trt and self.device.type != 'cpu' and not engine_file.is_file(): + engine_file = model.export(format='engine', half=True, imgsz=self.imgsz, device=self.device) + onnx_file = model.export(format='onnx', half=True, imgsz=self.imgsz, simplify=True, device=self.device) elif file.suffix == '.onnx': model_info = self.get_onnx_model_info(file) onnx_file = file @@ -197,8 +198,10 @@ class ProfileModels: t_engine = self.profile_tensorrt_model(str(engine_file)) t_onnx = self.profile_onnx_model(str(onnx_file)) table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info)) + output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info)) self.print_table(table_rows) + return output def get_files(self): files = [] @@ -219,7 +222,7 @@ class ProfileModels: # return (num_layers, num_params, num_gradients, num_flops) return 0.0, 0.0, 0.0, 0.0 - def iterative_sigma_clipping(self, data, sigma=2, max_iters=5): + def iterative_sigma_clipping(self, data, sigma=2, max_iters=3): data = np.array(data) for _ in range(max_iters): mean, std = np.mean(data), np.std(data) @@ -235,13 +238,13 @@ class ProfileModels: # Warmup runs model = YOLO(engine_file) - input_data = np.random.rand(self.imgsz, self.imgsz, 3).astype(np.float32) + input_data = np.random.rand(self.imgsz, self.imgsz, 3).astype(np.float32) # must be FP32 for _ in range(self.num_warmup_runs): model(input_data, verbose=False) # Timed runs run_times = [] - for _ in tqdm(range(self.num_timed_runs * 30), desc=engine_file): + for _ in tqdm(range(self.num_timed_runs * 50), desc=engine_file): results = model(input_data, verbose=False) run_times.append(results[0].speed['inference']) # Convert to milliseconds @@ -255,6 +258,7 @@ class ProfileModels: # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider' sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.intra_op_num_threads = 8 # Limit the number of threads sess = ort.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider']) input_tensor = sess.get_inputs()[0] @@ -289,13 +293,22 @@ class ProfileModels: sess.run([output_name], {input_name: input_data}) run_times.append((time.time() - start_time) * 1000) # Convert to milliseconds - run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping + run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=5) # sigma clipping return np.mean(run_times), np.std(run_times) def generate_table_row(self, model_name, t_onnx, t_engine, model_info): layers, params, gradients, flops = model_info return f'| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |' + def generate_results_dict(self, model_name, t_onnx, t_engine, model_info): + layers, params, gradients, flops = model_info + return { + 'model/name': model_name, + 'model/parameters': params, + 'model/GFLOPs': round(flops, 3), + 'model/speed_ONNX(ms)': round(t_onnx[0], 3), + 'model/speed_TensorRT(ms)': round(t_engine[0], 3)} + def print_table(self, table_rows): gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU' header = f'| Model | size
(pixels) | mAPval
50-95 | Speed
CPU ONNX
(ms) | Speed
{gpu} TensorRT
(ms) | params
(M) | FLOPs
(B) |' diff --git a/ultralytics/yolo/utils/callbacks/__init__.py b/ultralytics/yolo/utils/callbacks/__init__.py index 1071ef4..8ad4ad6 100644 --- a/ultralytics/yolo/utils/callbacks/__init__.py +++ b/ultralytics/yolo/utils/callbacks/__init__.py @@ -1,3 +1,5 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + from .base import add_integration_callbacks, default_callbacks, get_default_callbacks __all__ = 'add_integration_callbacks', 'default_callbacks', 'get_default_callbacks' diff --git a/ultralytics/yolo/utils/callbacks/clearml.py b/ultralytics/yolo/utils/callbacks/clearml.py index 094ad10..2cfdd73 100644 --- a/ultralytics/yolo/utils/callbacks/clearml.py +++ b/ultralytics/yolo/utils/callbacks/clearml.py @@ -1,11 +1,12 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license + import re import matplotlib.image as mpimg import matplotlib.pyplot as plt from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING -from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params +from ultralytics.yolo.utils.torch_utils import model_info_for_loggers try: import clearml @@ -105,11 +106,7 @@ def on_fit_epoch_end(trainer): value=trainer.epoch_time, iteration=trainer.epoch) if trainer.epoch == 0: - model_info = { - 'model/parameters': get_num_params(trainer.model), - 'model/GFLOPs': round(get_flops(trainer.model), 3), - 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)} - for k, v in model_info.items(): + for k, v in model_info_for_loggers(trainer).items(): task.get_logger().report_single_value(k, v) diff --git a/ultralytics/yolo/utils/callbacks/comet.py b/ultralytics/yolo/utils/callbacks/comet.py index f35eed2..bbf93ab 100644 --- a/ultralytics/yolo/utils/callbacks/comet.py +++ b/ultralytics/yolo/utils/callbacks/comet.py @@ -1,9 +1,10 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license + import os from pathlib import Path from ultralytics.yolo.utils import LOGGER, RANK, TESTS_RUNNING, ops -from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params +from ultralytics.yolo.utils.torch_utils import model_info_for_loggers try: import comet_ml @@ -324,11 +325,7 @@ def on_fit_epoch_end(trainer): experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch) experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch) if curr_epoch == 1: - model_info = { - 'model/parameters': get_num_params(trainer.model), - 'model/GFLOPs': round(get_flops(trainer.model), 3), - 'model/speed(ms)': round(trainer.validator.speed['inference'], 3), } - experiment.log_metrics(model_info, step=curr_step, epoch=curr_epoch) + experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch) if not save_assets: return diff --git a/ultralytics/yolo/utils/callbacks/hub.py b/ultralytics/yolo/utils/callbacks/hub.py index 3617a5a..e3b3427 100644 --- a/ultralytics/yolo/utils/callbacks/hub.py +++ b/ultralytics/yolo/utils/callbacks/hub.py @@ -5,7 +5,7 @@ from time import time from ultralytics.hub.utils import PREFIX, events from ultralytics.yolo.utils import LOGGER -from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params +from ultralytics.yolo.utils.torch_utils import model_info_for_loggers def on_pretrain_routine_end(trainer): @@ -24,11 +24,7 @@ def on_fit_epoch_end(trainer): # Upload metrics after val end all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics} if trainer.epoch == 0: - model_info = { - 'model/parameters': get_num_params(trainer.model), - 'model/GFLOPs': round(get_flops(trainer.model), 3), - 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)} - all_plots = {**all_plots, **model_info} + all_plots = {**all_plots, **model_info_for_loggers(trainer)} session.metrics_queue[trainer.epoch] = json.dumps(all_plots) if time() - session.timers['metrics'] > session.rate_limits['metrics']: session.upload_metrics() diff --git a/ultralytics/yolo/utils/callbacks/neptune.py b/ultralytics/yolo/utils/callbacks/neptune.py index 1355d81..96cb049 100644 --- a/ultralytics/yolo/utils/callbacks/neptune.py +++ b/ultralytics/yolo/utils/callbacks/neptune.py @@ -1,9 +1,10 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license + import matplotlib.image as mpimg import matplotlib.pyplot as plt from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING -from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params +from ultralytics.yolo.utils.torch_utils import model_info_for_loggers try: import neptune @@ -68,11 +69,7 @@ def on_train_epoch_end(trainer): def on_fit_epoch_end(trainer): """Callback function called at end of each fit (train+val) epoch.""" if run and trainer.epoch == 0: - model_info = { - 'parameters': get_num_params(trainer.model), - 'GFLOPs': round(get_flops(trainer.model), 3), - 'speed(ms)': round(trainer.validator.speed['inference'], 3)} - run['Configuration/Model'] = model_info + run['Configuration/Model'] = model_info_for_loggers(trainer) _log_scalars(trainer.metrics, trainer.epoch + 1) diff --git a/ultralytics/yolo/utils/callbacks/raytune.py b/ultralytics/yolo/utils/callbacks/raytune.py index 1fff729..1f53225 100644 --- a/ultralytics/yolo/utils/callbacks/raytune.py +++ b/ultralytics/yolo/utils/callbacks/raytune.py @@ -1,3 +1,5 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + try: import ray from ray import tune diff --git a/ultralytics/yolo/utils/callbacks/tensorboard.py b/ultralytics/yolo/utils/callbacks/tensorboard.py index 8c14dcb..a436b9c 100644 --- a/ultralytics/yolo/utils/callbacks/tensorboard.py +++ b/ultralytics/yolo/utils/callbacks/tensorboard.py @@ -1,4 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license + from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING, colorstr try: diff --git a/ultralytics/yolo/utils/callbacks/wb.py b/ultralytics/yolo/utils/callbacks/wb.py index f8776cd..2b3d40d 100644 --- a/ultralytics/yolo/utils/callbacks/wb.py +++ b/ultralytics/yolo/utils/callbacks/wb.py @@ -1,30 +1,27 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params +from ultralytics.yolo.utils import TESTS_RUNNING +from ultralytics.yolo.utils.torch_utils import model_info_for_loggers try: import wandb as wb assert hasattr(wb, '__version__') + assert not TESTS_RUNNING # do not log pytest except (ImportError, AssertionError): wb = None def on_pretrain_routine_start(trainer): """Initiate and start project if module is present.""" - wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars( - trainer.args)) if not wb.run else wb.run + wb.run or wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(trainer.args)) def on_fit_epoch_end(trainer): """Logs training metrics and model information at the end of an epoch.""" wb.run.log(trainer.metrics, step=trainer.epoch + 1) if trainer.epoch == 0: - model_info = { - 'model/parameters': get_num_params(trainer.model), - 'model/GFLOPs': round(get_flops(trainer.model), 3), - 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)} - wb.run.log(model_info, step=trainer.epoch + 1) + wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1) def on_train_epoch_end(trainer): diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py index f6862fc..98c0302 100644 --- a/ultralytics/yolo/utils/torch_utils.py +++ b/ultralytics/yolo/utils/torch_utils.py @@ -192,6 +192,29 @@ def get_num_gradients(model): return sum(x.numel() for x in model.parameters() if x.requires_grad) +def model_info_for_loggers(trainer): + """ + Return model info dict with useful model information. + + Example for YOLOv8n: + {'model/parameters': 3151904, + 'model/GFLOPs': 8.746, + 'model/speed_ONNX(ms)': 41.244, + 'model/speed_TensorRT(ms)': 3.211, + 'model/speed_PyTorch(ms)': 18.755} + """ + if trainer.args.profile: # profile ONNX and TensorRT times + from ultralytics.yolo.utils.benchmarks import ProfileModels + results = ProfileModels([trainer.last], device=trainer.device).profile()[0] + results.pop('model/name') + else: # only return PyTorch times from most recent validation + results = { + 'model/parameters': get_num_params(trainer.model), + 'model/GFLOPs': round(get_flops(trainer.model), 3)} + results['model/speed_PyTorch(ms)'] = round(trainer.validator.speed['inference'], 3) + return results + + def get_flops(model, imgsz=640): """Return a YOLO model's FLOPs.""" try: