From 3acead7e79a8769112a72fad430ef07fc2a4ea48 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 21 Aug 2023 01:22:29 +0200 Subject: [PATCH] Add TensorBoard graph for model visualization (#4464) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/reference/models/yolo/segment/train.md | 4 --- docs/reference/models/yolo/segment/val.md | 4 --- docs/reference/utils/callbacks/tensorboard.md | 4 +++ ultralytics/models/yolo/segment/train.py | 11 ------- ultralytics/models/yolo/segment/val.py | 13 +-------- ultralytics/utils/callbacks/tensorboard.py | 29 +++++++++++++++---- 6 files changed, 29 insertions(+), 36 deletions(-) diff --git a/docs/reference/models/yolo/segment/train.md b/docs/reference/models/yolo/segment/train.md index 65aff01..74434a5 100644 --- a/docs/reference/models/yolo/segment/train.md +++ b/docs/reference/models/yolo/segment/train.md @@ -12,7 +12,3 @@ keywords: Ultralytics, YOLO, SegmentationTrainer, image segmentation, object det --- ## ::: ultralytics.models.yolo.segment.train.SegmentationTrainer

- ---- -## ::: ultralytics.models.yolo.segment.train.train -

diff --git a/docs/reference/models/yolo/segment/val.md b/docs/reference/models/yolo/segment/val.md index 5b20f7e..58d8b9d 100644 --- a/docs/reference/models/yolo/segment/val.md +++ b/docs/reference/models/yolo/segment/val.md @@ -12,7 +12,3 @@ keywords: Ultralytics, YOLO, SegmentationValidator, model segmentation, image cl --- ## ::: ultralytics.models.yolo.segment.val.SegmentationValidator

- ---- -## ::: ultralytics.models.yolo.segment.val.val -

diff --git a/docs/reference/utils/callbacks/tensorboard.md b/docs/reference/utils/callbacks/tensorboard.md index 3cdd2d5..95db268 100644 --- a/docs/reference/utils/callbacks/tensorboard.md +++ b/docs/reference/utils/callbacks/tensorboard.md @@ -13,6 +13,10 @@ keywords: Ultralytics, YOLO, documentation, callback utilities, log_scalars, on_ ## ::: ultralytics.utils.callbacks.tensorboard._log_scalars

+--- +## ::: ultralytics.utils.callbacks.tensorboard._log_tensorboard_graph +

+ --- ## ::: ultralytics.utils.callbacks.tensorboard.on_pretrain_routine_start

diff --git a/ultralytics/models/yolo/segment/train.py b/ultralytics/models/yolo/segment/train.py index c6e148b..b290192 100644 --- a/ultralytics/models/yolo/segment/train.py +++ b/ultralytics/models/yolo/segment/train.py @@ -56,14 +56,3 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer): def plot_metrics(self): """Plots training/val metrics.""" plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png - - -def train(cfg=DEFAULT_CFG): - """Train a YOLO segmentation model based on passed arguments.""" - args = dict(model=cfg.model or 'yolov8n-seg.pt', data=cfg.data or 'coco8-seg.yaml') - trainer = SegmentationTrainer(overrides=args) - trainer.train() - - -if __name__ == '__main__': - train() diff --git a/ultralytics/models/yolo/segment/val.py b/ultralytics/models/yolo/segment/val.py index 6a3aa15..5e074ad 100644 --- a/ultralytics/models/yolo/segment/val.py +++ b/ultralytics/models/yolo/segment/val.py @@ -8,7 +8,7 @@ import torch import torch.nn.functional as F from ultralytics.models.yolo.detect import DetectionValidator -from ultralytics.utils import DEFAULT_CFG, LOGGER, NUM_THREADS, ops +from ultralytics.utils import LOGGER, NUM_THREADS, ops from ultralytics.utils.checks import check_requirements from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou from ultralytics.utils.plotting import output_to_target, plot_images @@ -243,14 +243,3 @@ class SegmentationValidator(DetectionValidator): except Exception as e: LOGGER.warning(f'pycocotools unable to run: {e}') return stats - - -def val(cfg=DEFAULT_CFG): - """Validate trained YOLO model on validation data.""" - args = dict(model=cfg.model or 'yolov8n-seg.pt', data=cfg.data or 'coco8-seg.yaml') - validator = SegmentationValidator(args=args) - validator(model=args['model']) - - -if __name__ == '__main__': - val() diff --git a/ultralytics/utils/callbacks/tensorboard.py b/ultralytics/utils/callbacks/tensorboard.py index 696a1b4..c542c9e 100644 --- a/ultralytics/utils/callbacks/tensorboard.py +++ b/ultralytics/utils/callbacks/tensorboard.py @@ -12,24 +12,43 @@ try: except (ImportError, AssertionError, TypeError): SummaryWriter = None -writer = None # TensorBoard SummaryWriter instance +WRITER = None # TensorBoard SummaryWriter instance def _log_scalars(scalars, step=0): """Logs scalar values to TensorBoard.""" - if writer: + if WRITER: for k, v in scalars.items(): - writer.add_scalar(k, v, step) + WRITER.add_scalar(k, v, step) + + +def _log_tensorboard_graph(trainer): + # Log model graph to TensorBoard + try: + import warnings + + from ultralytics.utils.torch_utils import de_parallel, torch + + imgsz = trainer.args.imgsz + imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz + p = next(trainer.model.parameters()) # for device, type + im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input (WARNING: must be zeros, not empty) + with warnings.catch_warnings(category=UserWarning): + warnings.simplefilter('ignore') # suppress jit trace warning + WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), []) + except Exception as e: + LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}') def on_pretrain_routine_start(trainer): """Initialize TensorBoard logging with SummaryWriter.""" if SummaryWriter: try: - global writer - writer = SummaryWriter(str(trainer.save_dir)) + global WRITER + WRITER = SummaryWriter(str(trainer.save_dir)) prefix = colorstr('TensorBoard: ') LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/") + _log_tensorboard_graph(trainer) except Exception as e: LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')