diff --git a/docs/reference/models/yolo/segment/train.md b/docs/reference/models/yolo/segment/train.md
index 65aff01..74434a5 100644
--- a/docs/reference/models/yolo/segment/train.md
+++ b/docs/reference/models/yolo/segment/train.md
@@ -12,7 +12,3 @@ keywords: Ultralytics, YOLO, SegmentationTrainer, image segmentation, object det
---
## ::: ultralytics.models.yolo.segment.train.SegmentationTrainer
-
----
-## ::: ultralytics.models.yolo.segment.train.train
-
diff --git a/docs/reference/models/yolo/segment/val.md b/docs/reference/models/yolo/segment/val.md
index 5b20f7e..58d8b9d 100644
--- a/docs/reference/models/yolo/segment/val.md
+++ b/docs/reference/models/yolo/segment/val.md
@@ -12,7 +12,3 @@ keywords: Ultralytics, YOLO, SegmentationValidator, model segmentation, image cl
---
## ::: ultralytics.models.yolo.segment.val.SegmentationValidator
-
----
-## ::: ultralytics.models.yolo.segment.val.val
-
diff --git a/docs/reference/utils/callbacks/tensorboard.md b/docs/reference/utils/callbacks/tensorboard.md
index 3cdd2d5..95db268 100644
--- a/docs/reference/utils/callbacks/tensorboard.md
+++ b/docs/reference/utils/callbacks/tensorboard.md
@@ -13,6 +13,10 @@ keywords: Ultralytics, YOLO, documentation, callback utilities, log_scalars, on_
## ::: ultralytics.utils.callbacks.tensorboard._log_scalars
+---
+## ::: ultralytics.utils.callbacks.tensorboard._log_tensorboard_graph
+
+
---
## ::: ultralytics.utils.callbacks.tensorboard.on_pretrain_routine_start
diff --git a/ultralytics/models/yolo/segment/train.py b/ultralytics/models/yolo/segment/train.py
index c6e148b..b290192 100644
--- a/ultralytics/models/yolo/segment/train.py
+++ b/ultralytics/models/yolo/segment/train.py
@@ -56,14 +56,3 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
-
-
-def train(cfg=DEFAULT_CFG):
- """Train a YOLO segmentation model based on passed arguments."""
- args = dict(model=cfg.model or 'yolov8n-seg.pt', data=cfg.data or 'coco8-seg.yaml')
- trainer = SegmentationTrainer(overrides=args)
- trainer.train()
-
-
-if __name__ == '__main__':
- train()
diff --git a/ultralytics/models/yolo/segment/val.py b/ultralytics/models/yolo/segment/val.py
index 6a3aa15..5e074ad 100644
--- a/ultralytics/models/yolo/segment/val.py
+++ b/ultralytics/models/yolo/segment/val.py
@@ -8,7 +8,7 @@ import torch
import torch.nn.functional as F
from ultralytics.models.yolo.detect import DetectionValidator
-from ultralytics.utils import DEFAULT_CFG, LOGGER, NUM_THREADS, ops
+from ultralytics.utils import LOGGER, NUM_THREADS, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
from ultralytics.utils.plotting import output_to_target, plot_images
@@ -243,14 +243,3 @@ class SegmentationValidator(DetectionValidator):
except Exception as e:
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats
-
-
-def val(cfg=DEFAULT_CFG):
- """Validate trained YOLO model on validation data."""
- args = dict(model=cfg.model or 'yolov8n-seg.pt', data=cfg.data or 'coco8-seg.yaml')
- validator = SegmentationValidator(args=args)
- validator(model=args['model'])
-
-
-if __name__ == '__main__':
- val()
diff --git a/ultralytics/utils/callbacks/tensorboard.py b/ultralytics/utils/callbacks/tensorboard.py
index 696a1b4..c542c9e 100644
--- a/ultralytics/utils/callbacks/tensorboard.py
+++ b/ultralytics/utils/callbacks/tensorboard.py
@@ -12,24 +12,43 @@ try:
except (ImportError, AssertionError, TypeError):
SummaryWriter = None
-writer = None # TensorBoard SummaryWriter instance
+WRITER = None # TensorBoard SummaryWriter instance
def _log_scalars(scalars, step=0):
"""Logs scalar values to TensorBoard."""
- if writer:
+ if WRITER:
for k, v in scalars.items():
- writer.add_scalar(k, v, step)
+ WRITER.add_scalar(k, v, step)
+
+
+def _log_tensorboard_graph(trainer):
+ # Log model graph to TensorBoard
+ try:
+ import warnings
+
+ from ultralytics.utils.torch_utils import de_parallel, torch
+
+ imgsz = trainer.args.imgsz
+ imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
+ p = next(trainer.model.parameters()) # for device, type
+ im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input (WARNING: must be zeros, not empty)
+ with warnings.catch_warnings(category=UserWarning):
+ warnings.simplefilter('ignore') # suppress jit trace warning
+ WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
+ except Exception as e:
+ LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')
def on_pretrain_routine_start(trainer):
"""Initialize TensorBoard logging with SummaryWriter."""
if SummaryWriter:
try:
- global writer
- writer = SummaryWriter(str(trainer.save_dir))
+ global WRITER
+ WRITER = SummaryWriter(str(trainer.save_dir))
prefix = colorstr('TensorBoard: ')
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
+ _log_tensorboard_graph(trainer)
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')