`ultralytics 8.0.105` classification hyp fix and new `onplot` callbacks (#2684)

Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ivan Shcheklein <shcheklein@gmail.com>
Glenn Jocher 2 years ago committed by GitHub
parent b1119d512e
commit 23fc50641c

@ -33,9 +33,18 @@ def extract_classes_and_functions(filepath):
def create_markdown(py_filepath, module_path, classes, functions):
md_filepath = py_filepath.with_suffix('.md')
# Read existing content and keep header content between first two ---
header_content = ""
if md_filepath.exists():
with open(md_filepath, 'r') as file:
existing_content = file.read()
header_parts = existing_content.split('---', 2)
if len(header_parts) >= 3:
header_content = f"{header_parts[0]}---{header_parts[1]}---\n\n"
md_content = [f"# {class_name}\n---\n:::{module_path}.{class_name}\n<br><br>\n" for class_name in classes]
md_content.extend(f"# {func_name}\n---\n:::{module_path}.{func_name}\n<br><br>\n" for func_name in functions)
md_content = "\n".join(md_content)
md_content = header_content + "\n".join(md_content)
os.makedirs(os.path.dirname(md_filepath), exist_ok=True)
with open(md_filepath, 'w') as file:

@ -1,7 +1,81 @@
---
comments: true
description: Learn about the MNIST dataset, a large database of handwritten digits commonly used for training various image processing systems and machine learning models.
---
# 🚧 Page Under Construction ⚒
# MNIST Dataset
This page is currently under construction! 👷 Please check back later for updates. 😃🔜
The [MNIST](http://yann.lecun.com/exdb/mnist/) (Modified National Institute of Standards and Technology) dataset is a large database of handwritten digits that is commonly used for training various image processing systems and machine learning models. It was created by "re-mixing" the samples from NIST's original datasets and has become a benchmark for evaluating the performance of image classification algorithms.
## Key Features
- MNIST contains 60,000 training images and 10,000 testing images of handwritten digits.
- The dataset comprises grayscale images of size 28x28 pixels.
- The images are normalized to fit into a 28x28 pixel bounding box and anti-aliased, introducing grayscale levels.
- MNIST is widely used for training and testing in the field of machine learning, especially for image classification tasks.
## Dataset Structure
The MNIST dataset is split into two subsets:
1. **Training Set**: This subset contains 60,000 images of handwritten digits used for training machine learning models.
2. **Testing Set**: This subset consists of 10,000 images used for testing and benchmarking the trained models.
## Extended MNIST (EMNIST)
Extended MNIST (EMNIST) is a newer dataset developed and released by NIST to be the successor to MNIST. While MNIST included images only of handwritten digits, EMNIST includes all the images from NIST Special Database 19, which is a large database of handwritten uppercase and lowercase letters as well as digits. The images in EMNIST were converted into the same 28x28 pixel format, by the same process, as were the MNIST images. Accordingly, tools that work with the older, smaller MNIST dataset will likely work unmodified with EMNIST.
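
As a quick illustration (not part of this commit, and assuming `torchvision` is installed), both datasets expose the same 28x28 interface, so code written for MNIST typically runs on EMNIST unchanged:

```python
from torchvision import datasets, transforms

transform = transforms.ToTensor()  # 28x28 grayscale -> 1x28x28 float tensor

# MNIST and its successor EMNIST share the same image format and loader API
mnist = datasets.MNIST(root='data', train=True, download=True, transform=transform)
emnist = datasets.EMNIST(root='data', split='digits', train=True, download=True, transform=transform)

print(mnist[0][0].shape, emnist[0][0].shape)  # torch.Size([1, 28, 28]) for both
```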
## Applications
The MNIST dataset is widely used for training and evaluating deep learning models in image classification tasks, such as Convolutional Neural Networks (CNNs), Support Vector Machines (SVMs), and various other machine learning algorithms. The dataset's simple and well-structured format makes it an essential resource for researchers and practitioners in the field of machine learning and computer vision.
## Usage
To train a CNN model on the MNIST dataset for 100 epochs with an image size of 32x32, you can use the following code snippets. For a comprehensive list of available arguments, refer to the model [Training](../../modes/train.md) page.
!!! example "Train Example"
=== "Python"
```python
from ultralytics import YOLO
# Load a model
model = YOLO('yolov8n-cls.pt') # load a pretrained model (recommended for training)
# Train the model
model.train(data='mnist', epochs=100, imgsz=32)
```
=== "CLI"
```bash
# Start training from a pretrained *.pt model
yolo classify train data=mnist model=yolov8n-cls.pt epochs=100 imgsz=32
```
## Sample Images and Annotations
The MNIST dataset contains grayscale images of handwritten digits, providing a well-structured dataset for image classification tasks. Here are some examples of images from the dataset:
![Dataset sample image](https://upload.wikimedia.org/wikipedia/commons/2/27/MnistExamples.png)
The example showcases the variety and complexity of the handwritten digits in the MNIST dataset, highlighting the importance of a diverse dataset for training robust image classification models.
## Citations and Acknowledgments
If you use the MNIST dataset in your research or development work, please cite the following paper:
```bibtex
@article{lecun2010mnist,
title={MNIST handwritten digit database},
author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
volume={2},
year={2010}
}
```
We would like to acknowledge Yann LeCun, Corinna Cortes, and Christopher J.C. Burges for creating and maintaining the MNIST dataset as a valuable resource for the machine learning and computer vision research community. For more information about the MNIST dataset and its creators, visit the [MNIST dataset website](http://yann.lecun.com/exdb/mnist/).

@ -118,7 +118,7 @@ Auto-annotation is an essential feature that allows you to generate a segmentati
To auto-annotate your dataset using the Ultralytics framework, you can use the `auto_annotate` function as shown below:
```python
from ultralytics.yolo.data import auto_annotate
from ultralytics.yolo.data.annotator import auto_annotate
auto_annotate(data="path/to/images", det_model="yolov8x.pt", sam_model='sam_b.pt')
```

@ -9,6 +9,12 @@ description: Explore RT-DETR, a high-performance real-time object detector. Lear
Real-Time Detection Transformer (RT-DETR) is an end-to-end object detector that provides real-time performance while maintaining high accuracy. It efficiently processes multi-scale features by decoupling intra-scale interaction and cross-scale fusion, and supports flexible adjustment of inference speed using different decoder layers without retraining. RT-DETR outperforms many real-time object detectors on accelerated backends like CUDA with TensorRT.
![Model example image](https://user-images.githubusercontent.com/26833433/238963168-90e8483f-90aa-4eb6-a5e1-0d408b23dd33.png)
**Overview of RT-DETR.** Model architecture diagram showing the last three stages of the backbone {S3, S4, S5} as the input
to the encoder. The efficient hybrid encoder transforms multiscale features into a sequence of image features through intrascale feature interaction (AIFI) and cross-scale feature-fusion module (CCFM). The IoU-aware query selection is employed
to select a fixed number of image features to serve as initial object queries for the decoder. Finally, the decoder with auxiliary
prediction heads iteratively optimizes object queries to generate boxes and confidence scores ([source](https://arxiv.org/pdf/2304.08069.pdf)).
### Key Features
- **Efficient Hybrid Encoder:** RT-DETR uses an efficient hybrid encoder that processes multi-scale features by decoupling intra-scale interaction and cross-scale fusion. This design reduces computational costs and allows for real-time object detection.
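
For orientation, a minimal usage sketch (not part of this hunk; it assumes the `rtdetr-l.pt` weights filename and relies on the `RTDETR` class exported from `ultralytics/__init__.py`, visible later in this diff):

```python
from ultralytics import RTDETR

# Load a pretrained RT-DETR-l model (weights filename assumed) and run inference
model = RTDETR('rtdetr-l.pt')
results = model.predict('https://ultralytics.com/images/bus.jpg')
```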

@ -57,7 +57,7 @@ Auto-annotation is an essential feature that allows you to generate a [segmentat
To auto-annotate your dataset using the Ultralytics framework, you can use the `auto_annotate` function as shown below:
```python
from ultralytics.yolo.data import auto_annotate
from ultralytics.yolo.data.annotator import auto_annotate
auto_annotate(data="path/to/images", det_model="yolov8x.pt", sam_model='sam_b.pt')
```

@ -5,4 +5,4 @@ description: Learn how to use Ultralytics hub authentication in your projects wi
# Auth
---
:::ultralytics.hub.auth.Auth
<br><br>
<br><br>

@ -5,4 +5,4 @@ description: Accelerate your AI development with the Ultralytics HUB Training Se
# HUBTrainingSession
---
:::ultralytics.hub.session.HUBTrainingSession
<br><br>
<br><br>

@ -20,4 +20,4 @@ description: Explore Ultralytics events, including 'request_with_credentials' an
# smart_request
---
:::ultralytics.hub.utils.smart_request
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Ensure class names match filenames for easy imports. Use AutoBacken
# check_class_names
---
:::ultralytics.nn.autobackend.check_class_names
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Detect 80+ object categories with bounding box coordinates and clas
# Detections
---
:::ultralytics.nn.autoshape.Detections
<br><br>
<br><br>

@ -85,4 +85,4 @@ description: Explore ultralytics.nn.modules.block to build powerful YOLO object
# BottleneckCSP
---
:::ultralytics.nn.modules.block.BottleneckCSP
<br><br>
<br><br>

@ -65,4 +65,4 @@ description: Explore convolutional neural network modules & techniques such as L
# autopad
---
:::ultralytics.nn.modules.conv.autopad
<br><br>
<br><br>

@ -25,4 +25,4 @@ description: 'Learn about Ultralytics YOLO modules: Segment, Classify, and RTDET
# RTDETRDecoder
---
:::ultralytics.nn.modules.head.RTDETRDecoder
<br><br>
<br><br>

@ -50,4 +50,4 @@ description: Explore the Ultralytics nn modules pages on Transformer and MLP blo
# DeformableTransformerDecoder
---
:::ultralytics.nn.modules.transformer.DeformableTransformerDecoder
<br><br>
<br><br>

@ -25,4 +25,4 @@ description: 'Learn about Ultralytics NN modules: get_clones, linear_init_, and
# multi_scale_deformable_attn_pytorch
---
:::ultralytics.nn.modules.utils.multi_scale_deformable_attn_pytorch
<br><br>
<br><br>

@ -65,4 +65,4 @@ description: Learn how to work with Ultralytics YOLO Detection, Segmentation & C
# guess_model_task
---
:::ultralytics.nn.tasks.guess_model_task
<br><br>
<br><br>

@ -15,4 +15,4 @@ description: Learn how to register custom event-tracking and track predictions w
# register_tracker
---
:::ultralytics.tracker.track.register_tracker
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: 'TrackState: A comprehensive guide to Ultralytics tracker''s BaseTr
# BaseTrack
---
:::ultralytics.tracker.trackers.basetrack.BaseTrack
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: '"Optimize tracking with Ultralytics BOTrack. Easily sort and track
# BOTSORT
---
:::ultralytics.tracker.trackers.bot_sort.BOTSORT
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Learn how to track ByteAI model sizes and tips for model optimizati
# BYTETracker
---
:::ultralytics.tracker.trackers.byte_tracker.BYTETracker
<br><br>
<br><br>

@ -5,4 +5,4 @@ description: '"Track Google Marketing Campaigns in GMC with Ultralytics Tracker.
# GMC
---
:::ultralytics.tracker.utils.gmc.GMC
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Improve object tracking with KalmanFilterXYAH in Ultralytics YOLO -
# KalmanFilterXYWH
---
:::ultralytics.tracker.utils.kalman_filter.KalmanFilterXYWH
<br><br>
<br><br>

@ -60,4 +60,4 @@ description: Learn how to match and fuse object detections for accurate target t
# bbox_ious
---
:::ultralytics.tracker.utils.matching.bbox_ious
<br><br>
<br><br>

@ -5,4 +5,4 @@ description: Learn how to use auto_annotate in Ultralytics YOLO to generate anno
# auto_annotate
---
:::ultralytics.yolo.data.annotator.auto_annotate
<br><br>
<br><br>

@ -90,4 +90,4 @@ description: Use Ultralytics YOLO Data Augmentation transforms with Base, MixUp,
# classify_albumentations
---
:::ultralytics.yolo.data.augment.classify_albumentations
<br><br>
<br><br>

@ -5,4 +5,4 @@ description: Learn about BaseDataset in Ultralytics YOLO, a flexible dataset cla
# BaseDataset
---
:::ultralytics.yolo.data.base.BaseDataset
<br><br>
<br><br>

@ -35,4 +35,4 @@ description: Maximize YOLO performance with Ultralytics' InfiniteDataLoader, see
# load_inference_source
---
:::ultralytics.yolo.data.build.load_inference_source
<br><br>
<br><br>

@ -30,4 +30,4 @@ description: Convert COCO-91 to COCO-80 class, RLE to polygon, and merge multi-s
# delete_dsstore
---
:::ultralytics.yolo.data.converter.delete_dsstore
<br><br>
<br><br>

@ -35,4 +35,4 @@ description: 'Ultralytics YOLO Docs: Learn about stream loaders for image and te
# autocast_list
---
:::ultralytics.yolo.data.dataloaders.stream_loaders.autocast_list
<br><br>
<br><br>

@ -85,4 +85,4 @@ description: Enhance image data with Albumentations CenterCrop, normalize, augme
# classify_transforms
---
:::ultralytics.yolo.data.dataloaders.v5augmentations.classify_transforms
<br><br>
<br><br>

@ -90,4 +90,4 @@ description: Efficiently load images and labels to models using Ultralytics YOLO
# create_classification_dataloader
---
:::ultralytics.yolo.data.dataloaders.v5loader.create_classification_dataloader
<br><br>
<br><br>

@ -15,4 +15,4 @@ description: Create custom YOLOv5 datasets with Ultralytics YOLODataset and Sema
# SemanticDataset
---
:::ultralytics.yolo.data.dataset.SemanticDataset
<br><br>
<br><br>

@ -5,4 +5,4 @@ description: Create a custom dataset of mixed and oriented rectangular objects w
# MixAndRectDataset
---
:::ultralytics.yolo.data.dataset_wrappers.MixAndRectDataset
<br><br>
<br><br>

@ -65,4 +65,4 @@ description: Efficiently handle data in YOLO with Ultralytics. Utilize HUBDatase
# zip_directory
---
:::ultralytics.yolo.data.utils.zip_directory
<br><br>
<br><br>

@ -30,4 +30,4 @@ description: Learn how to export your YOLO model in various formats using Ultral
# export
---
:::ultralytics.yolo.engine.exporter.export
<br><br>
<br><br>

@ -5,4 +5,4 @@ description: Discover the YOLO model of Ultralytics engine to simplify your obje
# YOLO
---
:::ultralytics.yolo.engine.model.YOLO
<br><br>
<br><br>

@ -5,4 +5,4 @@ description: '"The BasePredictor class in Ultralytics YOLO Engine predicts objec
# BasePredictor
---
:::ultralytics.yolo.engine.predictor.BasePredictor
<br><br>
<br><br>

@ -20,4 +20,4 @@ description: Learn about BaseTensor & Boxes in Ultralytics YOLO Engine. Check ou
# Masks
---
:::ultralytics.yolo.engine.results.Masks
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Train faster with mixed precision. Learn how to use BaseTrainer wit
# check_amp
---
:::ultralytics.yolo.engine.trainer.check_amp
<br><br>
<br><br>

@ -5,4 +5,4 @@ description: Ensure YOLOv5 models meet constraints and standards with the BaseVa
# BaseValidator
---
:::ultralytics.yolo.engine.validator.BaseValidator
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Dynamically adjusts input size to optimize GPU memory usage during
# autobatch
---
:::ultralytics.yolo.utils.autobatch.autobatch
<br><br>
<br><br>

@ -5,4 +5,4 @@ description: Improve your YOLO's performance and measure its speed. Benchmark ut
# benchmark
---
:::ultralytics.yolo.utils.benchmarks.benchmark
<br><br>
<br><br>

@ -135,4 +135,4 @@ description: Learn about YOLO's callback functions from on_train_start to add_in
# add_integration_callbacks
---
:::ultralytics.yolo.utils.callbacks.base.add_integration_callbacks
<br><br>
<br><br>

@ -35,4 +35,4 @@ description: Improve your YOLOv5 model training with callbacks from ClearML. Lea
# on_train_end
---
:::ultralytics.yolo.utils.callbacks.clearml.on_train_end
<br><br>
<br><br>

@ -120,4 +120,4 @@ description: Learn about YOLO callbacks using the Comet.ml platform, enhancing o
# on_train_end
---
:::ultralytics.yolo.utils.callbacks.comet.on_train_end
<br><br>
<br><br>

@ -40,4 +40,4 @@ description: Improve YOLOv5 model training with Ultralytics' on-train callbacks.
# on_export_start
---
:::ultralytics.yolo.utils.callbacks.hub.on_export_start
<br><br>
<br><br>

@ -15,4 +15,4 @@ description: Track model performance and metrics with MLflow in YOLOv5. Use call
# on_train_end
---
:::ultralytics.yolo.utils.callbacks.mlflow.on_train_end
<br><br>
<br><br>

@ -40,4 +40,4 @@ description: Improve YOLOv5 training with Neptune, a powerful logging tool. Trac
# on_train_end
---
:::ultralytics.yolo.utils.callbacks.neptune.on_train_end
<br><br>
<br><br>

@ -5,4 +5,4 @@ description: '"Improve YOLO model performance with on_fit_epoch_end callback. Le
# on_fit_epoch_end
---
:::ultralytics.yolo.utils.callbacks.raytune.on_fit_epoch_end
<br><br>
<br><br>

@ -20,4 +20,4 @@ description: Learn how to monitor the training process with Tensorboard using Ul
# on_fit_epoch_end
---
:::ultralytics.yolo.utils.callbacks.tensorboard.on_fit_epoch_end
<br><br>
<br><br>

@ -20,4 +20,4 @@ description: Learn how to use Ultralytics YOLO's built-in callbacks `on_pretrain
# on_train_end
---
:::ultralytics.yolo.utils.callbacks.wb.on_train_end
<br><br>
<br><br>

@ -80,4 +80,4 @@ description: 'Check functions for YOLO utils: image size, version, font, require
# print_args
---
:::ultralytics.yolo.utils.checks.print_args
<br><br>
<br><br>

@ -20,4 +20,4 @@ description: Learn how to find free network port and generate DDP (Distributed D
# ddp_cleanup
---
:::ultralytics.yolo.utils.dist.ddp_cleanup
<br><br>
<br><br>

@ -30,4 +30,4 @@ description: Download and unzip YOLO pretrained models. Ultralytics YOLO docs ut
# download
---
:::ultralytics.yolo.utils.downloads.download
<br><br>
<br><br>

@ -5,4 +5,4 @@ description: Learn about HUBModelError in Ultralytics YOLO Docs. Resolve the err
# HUBModelError
---
:::ultralytics.yolo.utils.errors.HUBModelError
<br><br>
<br><br>

@ -35,4 +35,4 @@ description: 'Learn about Ultralytics YOLO files and directory utilities: Workin
# make_dirs
---
:::ultralytics.yolo.utils.files.make_dirs
<br><br>
<br><br>

@ -15,4 +15,4 @@ description: Learn about Bounding Boxes (Bboxes) and _ntuple in Ultralytics YOLO
# _ntuple
---
:::ultralytics.yolo.utils.instance._ntuple
<br><br>
<br><br>

@ -15,4 +15,4 @@ description: Learn about Varifocal Loss and Keypoint Loss in Ultralytics YOLO fo
# KeypointLoss
---
:::ultralytics.yolo.utils.loss.KeypointLoss
<br><br>
<br><br>

@ -95,4 +95,4 @@ description: Explore Ultralytics YOLO's FocalLoss, DetMetrics, PoseMetrics, Clas
# ap_per_class
---
:::ultralytics.yolo.utils.metrics.ap_per_class
<br><br>
<br><br>

@ -135,4 +135,4 @@ description: Learn about various utility functions in Ultralytics YOLO, includin
# clean_str
---
:::ultralytics.yolo.utils.ops.clean_str
<br><br>
<br><br>

@ -40,4 +40,4 @@ description: 'Discover the power of YOLO''s plotting functions: Colors, Labels a
# feature_visualization
---
:::ultralytics.yolo.utils.plotting.feature_visualization
<br><br>
<br><br>

@ -30,4 +30,4 @@ description: Improve your YOLO models with Ultralytics' TaskAlignedAssigner, sel
# bbox2dist
---
:::ultralytics.yolo.utils.tal.bbox2dist
<br><br>
<br><br>

@ -120,4 +120,4 @@ description: Optimize your PyTorch models with Ultralytics YOLO's torch_utils fu
# profile
---
:::ultralytics.yolo.utils.torch_utils.profile
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Learn how to use ClassificationPredictor in Ultralytics YOLOv8 for
# predict
---
:::ultralytics.yolo.v8.classify.predict.predict
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Train a custom image classification model using Ultralytics YOLOv8
# train
---
:::ultralytics.yolo.v8.classify.train.train
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Ensure model classification accuracy with Ultralytics YOLO's Classi
# val
---
:::ultralytics.yolo.v8.classify.val.val
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Detect and predict objects in images and videos using the Ultralyti
# predict
---
:::ultralytics.yolo.v8.detect.predict.predict
<br><br>
<br><br>

@ -15,4 +15,4 @@ description: Train and optimize custom object detection models with Ultralytics
# train
---
:::ultralytics.yolo.v8.detect.train.train
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Validate YOLOv5 detections using this PyTorch module. Ensure model
# val
---
:::ultralytics.yolo.v8.detect.val.val
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Predict human pose coordinates and confidence scores using YOLOv5.
# predict
---
:::ultralytics.yolo.v8.pose.predict.predict
<br><br>
<br><br>

@ -15,4 +15,4 @@ description: Boost posture detection using PoseTrainer and train models using tr
# train
---
:::ultralytics.yolo.v8.pose.train.train
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Ensure proper human poses in images with YOLOv8 Pose Validation, pa
# val
---
:::ultralytics.yolo.v8.pose.val.val
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: '"Use SegmentationPredictor in YOLOv8 for efficient object detectio
# predict
---
:::ultralytics.yolo.v8.segment.predict.predict
<br><br>
<br><br>

@ -15,4 +15,4 @@ description: Learn about SegmentationTrainer and Train in Ultralytics YOLO v8 fo
# train
---
:::ultralytics.yolo.v8.segment.train.train
<br><br>
<br><br>

@ -10,4 +10,4 @@ description: Ensure segmentation quality on large datasets with SegmentationVali
# val
---
:::ultralytics.yolo.v8.segment.val.val
<br><br>
<br><br>

@ -0,0 +1,19 @@
import pytest
def pytest_addoption(parser):
parser.addoption('--runslow', action='store_true', default=False, help='run slow tests')
def pytest_configure(config):
config.addinivalue_line('markers', 'slow: mark test as slow to run')
def pytest_collection_modifyitems(config, items):
if config.getoption('--runslow'):
# --runslow given in cli: do not skip slow tests
return
skip_slow = pytest.mark.skip(reason='need --runslow option to run')
for item in items:
if 'slow' in item.keywords:
item.add_marker(skip_slow)
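
With this `conftest.py` in place, plain `pytest` skips slow tests and `pytest --runslow` runs them; a test opts in via the registered marker, as the new `test_train_gpu` below does. A hypothetical example:

```python
import pytest

@pytest.mark.slow
def test_expensive_training():  # hypothetical test name, for illustration only
    assert 1 + 1 == 2
```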

@ -3,10 +3,17 @@
import subprocess
from pathlib import Path
from ultralytics.yolo.utils import LINUX, ONLINE, ROOT, SETTINGS
import pytest
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
CFG = 'yolov8n'
from ultralytics.yolo.utils import ONLINE, ROOT, SETTINGS
WEIGHT_DIR = Path(SETTINGS['weights_dir'])
TASK_ARGS = [ # (task, model, data)
('detect', 'yolov8n', 'coco8.yaml'), ('segment', 'yolov8n-seg', 'coco8-seg.yaml'),
('classify', 'yolov8n-cls', 'imagenet10'), ('pose', 'yolov8n-pose', 'coco8-pose.yaml')]
EXPORT_ARGS = [ # (model, format)
('yolov8n', 'torchscript'), ('yolov8n-seg', 'torchscript'), ('yolov8n-cls', 'torchscript'),
('yolov8n-pose', 'torchscript')]
def run(cmd):
@ -20,78 +27,33 @@ def test_special_modes():
run('yolo help')
# Train checks ---------------------------------------------------------------------------------------------------------
def test_train_det():
run(f'yolo train detect model={CFG}.yaml data=coco8.yaml imgsz=32 epochs=1 v5loader')
def test_train_seg():
run(f'yolo train segment model={CFG}-seg.yaml data=coco8-seg.yaml imgsz=32 epochs=1')
def test_train_cls():
run(f'yolo train classify model={CFG}-cls.yaml data=imagenet10 imgsz=32 epochs=1')
def test_train_pose():
run(f'yolo train pose model={CFG}-pose.yaml data=coco8-pose.yaml imgsz=32 epochs=1')
# Val checks -----------------------------------------------------------------------------------------------------------
def test_val_detect():
run(f'yolo val detect model={MODEL}.pt data=coco8.yaml imgsz=32')
def test_val_segment():
run(f'yolo val segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32')
def test_val_classify():
run(f'yolo val classify model={MODEL}-cls.pt data=imagenet10 imgsz=32')
@pytest.mark.parametrize('task,model,data', TASK_ARGS)
def test_train(task, model, data):
run(f'yolo train {task} model={model}.yaml data={data} imgsz=32 epochs=1')
def test_val_pose():
run(f'yolo val pose model={MODEL}-pose.pt data=coco8-pose.yaml imgsz=32')
@pytest.mark.parametrize('task,model,data', TASK_ARGS)
def test_val(task, model, data):
run(f'yolo val {task} model={model}.pt data={data} imgsz=32')
# Predict checks -------------------------------------------------------------------------------------------------------
def test_predict_detect():
run(f"yolo predict model={MODEL}.pt source={ROOT / 'assets'} imgsz=32 save save_crop save_txt")
@pytest.mark.parametrize('task,model,data', TASK_ARGS)
def test_predict(task, model, data):
run(f"yolo predict model={model}.pt source={ROOT / 'assets'} imgsz=32 save save_crop save_txt")
if ONLINE:
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32')
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32')
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32')
def test_predict_segment():
run(f"yolo predict model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32 save save_txt")
def test_predict_classify():
run(f"yolo predict model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32 save save_txt")
def test_predict_pose():
run(f"yolo predict model={MODEL}-pose.pt source={ROOT / 'assets'} imgsz=32 save save_txt")
# Export checks --------------------------------------------------------------------------------------------------------
def test_export_detect_torchscript():
run(f'yolo export model={MODEL}.pt format=torchscript')
def test_export_segment_torchscript():
run(f'yolo export model={MODEL}-seg.pt format=torchscript')
def test_export_classify_torchscript():
run(f'yolo export model={MODEL}-cls.pt format=torchscript')
run(f'yolo predict model={model}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32')
run(f'yolo predict model={model}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32')
run(f'yolo predict model={model}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32')
def test_export_classify_pose():
run(f'yolo export model={MODEL}-pose.pt format=torchscript')
@pytest.mark.parametrize('model,format', EXPORT_ARGS)
def test_export(model, format):
run(f'yolo export model={model}.pt format={format}')
def test_export_detect_edgetpu(enabled=False):
if enabled and LINUX:
run(f'yolo export model={MODEL}.pt format=edgetpu')
# Slow Tests
@pytest.mark.slow
@pytest.mark.parametrize('task,model,data', TASK_ARGS)
def test_train_gpu(task, model, data):
run(f'yolo train {task} model={model}.yaml data={data} imgsz=32 epochs=1 device="0"') # single GPU
run(f'yolo train {task} model={model}.pt data={data} imgsz=32 epochs=1 device="0,1"') # Multi GPU

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.104'
__version__ = '8.0.105'
from ultralytics.hub import start
from ultralytics.vit.rtdetr import RTDETR

@ -789,13 +789,20 @@ def classify_transforms(size=224, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)): #
return T.Compose([CenterCrop(size), ToTensor()])
def hsv2colorjitter(h, s, v):
"""Map HSV (hue, saturation, value) jitter into ColorJitter values (brightness, contrast, saturation, hue)"""
return v, v, s, h
def classify_albumentations(
augment=True,
size=224,
scale=(0.08, 1.0),
hflip=0.5,
vflip=0.0,
jitter=0.4,
hsv_h=0.015, # image HSV-Hue augmentation (fraction)
hsv_s=0.7, # image HSV-Saturation augmentation (fraction)
hsv_v=0.4, # image HSV-Value augmentation (fraction)
mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
std=(1.0, 1.0, 1.0), # IMAGENET_STD
auto_aug=False,
@ -810,16 +817,15 @@ def classify_albumentations(
if augment: # Resize and crop
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
if auto_aug:
# TODO: implement AugMix, AutoAug & RandAug in albumentation
# TODO: implement AugMix, AutoAug & RandAug in albumentations
LOGGER.info(f'{prefix}auto augmentations are currently not supported')
else:
if hflip > 0:
T += [A.HorizontalFlip(p=hflip)]
if vflip > 0:
T += [A.VerticalFlip(p=vflip)]
if jitter > 0:
jitter = float(jitter)
T += [A.ColorJitter(jitter, jitter, jitter, 0)] # brightness, contrast, saturation, 0 hue
if any((hsv_h, hsv_s, hsv_v)):
T += [A.ColorJitter(*hsv2colorjitter(hsv_h, hsv_s, hsv_v))] # brightness, contrast, saturation, hue
else: # Use fixed crop for eval set (reproducibility)
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
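
A small sketch of the updated interface (assumes `albumentations` is installed; the argument values mirror the defaults introduced above):

```python
import numpy as np
from ultralytics.yolo.data.augment import classify_albumentations

# Build the train-time pipeline with the new HSV-style jitter arguments
transform = classify_albumentations(augment=True, size=224, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4)

im = np.zeros((256, 320, 3), dtype=np.uint8)  # dummy HWC uint8 image
out = transform(image=im)['image']  # Albumentations pipelines take and return dicts
print(out.shape)  # torch.Size([3, 224, 224]) after ToTensorV2
```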

@ -202,21 +202,48 @@ class YOLODataset(BaseDataset):
# Classification dataloaders -------------------------------------------------------------------------------------------
class ClassificationDataset(torchvision.datasets.ImageFolder):
"""
YOLOv5 Classification Dataset.
Arguments
root: Dataset path
transform: torchvision transforms, used by default
album_transform: Albumentations transforms, used if installed
YOLO Classification Dataset.
Args:
root (str): Dataset path.
transform (callable, optional): torchvision transforms, used by default.
album_transform (callable, optional): Albumentations transforms, used if installed.
Attributes:
cache_ram (bool): True if images should be cached in RAM, False otherwise.
cache_disk (bool): True if images should be cached on disk, False otherwise.
samples (list): List of samples containing file, index, npy, and im.
torch_transforms (callable): torchvision transforms applied to the dataset.
album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
"""
def __init__(self, root, augment=False, imgsz=224, cache=False):
"""Initialize YOLO object with root, image size, augmentations, and cache settings"""
def __init__(self, root, args, augment=False, cache=False):
"""
Initialize YOLO object with root, image size, augmentations, and cache settings.
Args:
root (str): Dataset path.
args (Namespace): Argument parser containing dataset related settings.
augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False.
cache (Union[bool, str], optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
"""
super().__init__(root=root)
self.torch_transforms = classify_transforms(imgsz)
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
self.cache_ram = cache is True or cache == 'ram'
self.cache_disk = cache == 'disk'
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
self.torch_transforms = classify_transforms(args.imgsz)
self.album_transforms = classify_albumentations(
augment=augment,
size=args.imgsz,
scale=(1.0 - args.scale, 1.0), # (0.08, 1.0)
hflip=args.fliplr,
vflip=args.flipud,
hsv_h=args.hsv_h, # HSV-Hue augmentation (fraction)
hsv_s=args.hsv_s, # HSV-Saturation augmentation (fraction)
hsv_v=args.hsv_v, # HSV-Value augmentation (fraction)
mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
std=(1.0, 1.0, 1.0), # IMAGENET_STD
auto_aug=False) if augment else None
def __getitem__(self, i):
"""Returns subset of data and targets corresponding to given indices."""

@ -85,6 +85,7 @@ class BaseTrainer:
self.validator = None
self.model = None
self.metrics = None
self.plots = {}
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
# Dirs
@ -537,6 +538,10 @@ class BaseTrainer:
"""Plot and display metrics visually."""
pass
def on_plot(self, name, data=None):
"""Registers plots (e.g. to be consumed in callbacks)"""
self.plots[name] = {'data': data, 'timestamp': time.time()}
def final_eval(self):
"""Performs final evaluation and validation for object detection YOLO model."""
for f in self.last, self.best:
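
Downstream, a logger callback can consume the registry these `on_plot` hooks populate. A sketch (assuming the standard `add_callback` registration on the `YOLO` model):

```python
from ultralytics import YOLO

def log_plots(trainer):
    """Illustrative callback: read plots registered via trainer.on_plot()."""
    for name, plot in trainer.plots.items():
        print(f'{name} rendered at {plot["timestamp"]}')

model = YOLO('yolov8n.pt')
model.add_callback('on_fit_epoch_end', log_plots)
model.train(data='coco8.yaml', epochs=1, imgsz=32)
```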

@ -19,6 +19,7 @@ Usage - formats:
yolov8n_paddle_model # PaddlePaddle
"""
import json
import time
from pathlib import Path
import torch
@ -84,6 +85,7 @@ class BaseValidator:
if self.args.conf is None:
self.args.conf = 0.001 # default conf=0.001
self.plots = {}
self.callbacks = _callbacks or callbacks.get_default_callbacks()
@smart_inference_mode()
@ -252,6 +254,10 @@ class BaseValidator:
"""Returns the metric keys used in YOLO training/validation."""
return []
def on_plot(self, name, data=None):
"""Registers plots (e.g. to be consumed in callbacks)"""
self.plots[name] = {'data': data, 'timestamp': time.time()}
# TODO: may need to put these following functions into callback
def plot_val_samples(self, batch, ni):
"""Plots validation samples during training."""

@ -300,7 +300,7 @@ class ConfusionMatrix:
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
@plt_settings()
def plot(self, normalize=True, save_dir='', names=()):
def plot(self, normalize=True, save_dir='', names=(), on_plot=None):
"""
Plot the confusion matrix using seaborn and save it to a file.
@ -308,6 +308,7 @@ class ConfusionMatrix:
normalize (bool): Whether to normalize the confusion matrix.
save_dir (str): Directory where the plot will be saved.
names (tuple): Names of classes, used as labels on the plot.
on_plot (func): An optional callback to pass plots path and data when they are rendered.
"""
import seaborn as sn
@ -336,8 +337,11 @@ class ConfusionMatrix:
ax.set_xlabel('True')
ax.set_ylabel('Predicted')
ax.set_title(title)
fig.savefig(Path(save_dir) / f'{title.lower().replace(" ", "_")}.png', dpi=250)
plot_fname = Path(save_dir) / f'{title.lower().replace(" ", "_")}.png'
fig.savefig(plot_fname, dpi=250)
plt.close(fig)
if on_plot:
on_plot(plot_fname)
def print(self):
"""
@ -356,7 +360,7 @@ def smooth(y, f=0.05):
@plt_settings()
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=(), on_plot=None):
"""Plots a precision-recall curve."""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
py = np.stack(py, axis=1)
@ -376,10 +380,12 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
ax.set_title('Precision-Recall Curve')
fig.savefig(save_dir, dpi=250)
plt.close(fig)
if on_plot:
on_plot(save_dir)
@plt_settings()
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric', on_plot=None):
"""Plots a metric-confidence curve."""
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
@ -399,6 +405,8 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
ax.set_title(f'{ylabel}-Confidence Curve')
fig.savefig(save_dir, dpi=250)
plt.close(fig)
if on_plot:
on_plot(save_dir)
def compute_ap(recall, precision):
@ -434,7 +442,16 @@ def compute_ap(recall, precision):
return ap, mpre, mrec
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=''):
def ap_per_class(tp,
conf,
pred_cls,
target_cls,
plot=False,
on_plot=None,
save_dir=Path(),
names=(),
eps=1e-16,
prefix=''):
"""
Computes the average precision per class for object detection evaluation.
@ -444,6 +461,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na
pred_cls (np.ndarray): Array of predicted classes of the detections.
target_cls (np.ndarray): Array of true classes of the detections.
plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None.
save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
@ -502,10 +520,10 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
names = dict(enumerate(names)) # to dict
if plot:
plot_pr_curve(px, py, ap, save_dir / f'{prefix}PR_curve.png', names)
plot_mc_curve(px, f1, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1')
plot_mc_curve(px, p, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision')
plot_mc_curve(px, r, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall')
plot_pr_curve(px, py, ap, save_dir / f'{prefix}PR_curve.png', names, on_plot=on_plot)
plot_mc_curve(px, f1, save_dir / f'{prefix}F1_curve.png', names, ylabel='F1', on_plot=on_plot)
plot_mc_curve(px, p, save_dir / f'{prefix}P_curve.png', names, ylabel='Precision', on_plot=on_plot)
plot_mc_curve(px, r, save_dir / f'{prefix}R_curve.png', names, ylabel='Recall', on_plot=on_plot)
i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
p, r, f1 = p[:, i], r[:, i], f1[:, i]
@ -657,11 +675,13 @@ class DetMetrics(SimpleClass):
Args:
save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
names (tuple of str): A tuple of strings that represents the names of the classes. Defaults to an empty tuple.
Attributes:
save_dir (Path): A path to the directory where the output plots will be saved.
plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
on_plot (func): An optional callback to pass plots path and data when they are rendered.
names (tuple of str): A tuple of strings that represents the names of the classes.
box (Metric): An instance of the Metric class for storing the results of the detection metrics.
speed (dict): A dictionary for storing the execution time of different parts of the detection process.
@ -677,9 +697,10 @@ class DetMetrics(SimpleClass):
results_dict: Returns a dictionary that maps detection metric keys to their computed values.
"""
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
@ -732,11 +753,13 @@ class SegmentMetrics(SimpleClass):
Args:
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
plot (bool): Whether to save the detection and segmentation plots. Default is False.
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
names (list): List of class names. Default is an empty list.
Attributes:
save_dir (Path): Path to the directory where the output plots should be saved.
plot (bool): Whether to save the detection and segmentation plots.
on_plot (func): An optional callback to pass plots path and data when they are rendered.
names (list): List of class names.
box (Metric): An instance of the Metric class to calculate box detection metrics.
seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
@ -752,9 +775,10 @@ class SegmentMetrics(SimpleClass):
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
"""
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
self.seg = Metric()
@ -777,6 +801,7 @@ class SegmentMetrics(SimpleClass):
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix='Mask')[2:]
@ -787,6 +812,7 @@ class SegmentMetrics(SimpleClass):
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix='Box')[2:]
@ -836,11 +862,13 @@ class PoseMetrics(SegmentMetrics):
Args:
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
plot (bool): Whether to save the detection and segmentation plots. Default is False.
on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
names (list): List of class names. Default is an empty list.
Attributes:
save_dir (Path): Path to the directory where the output plots should be saved.
plot (bool): Whether to save the detection and segmentation plots.
on_plot (func): An optional callback to pass plots path and data when they are rendered.
names (list): List of class names.
box (Metric): An instance of the Metric class to calculate box detection metrics.
pose (Metric): An instance of the Metric class to calculate mask segmentation metrics.
@ -856,10 +884,11 @@ class PoseMetrics(SegmentMetrics):
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
"""
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
def __init__(self, save_dir=Path('.'), plot=False, on_plot=None, names=()) -> None:
super().__init__(save_dir, plot, names)
self.save_dir = save_dir
self.plot = plot
self.on_plot = on_plot
self.names = names
self.box = Metric()
self.pose = Metric()
@ -887,6 +916,7 @@ class PoseMetrics(SegmentMetrics):
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix='Pose')[2:]
@ -897,6 +927,7 @@ class PoseMetrics(SegmentMetrics):
pred_cls,
target_cls,
plot=self.plot,
on_plot=self.on_plot,
save_dir=self.save_dir,
names=self.names,
prefix='Box')[2:]

@ -228,7 +228,7 @@ class Annotator:
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path('')):
def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
"""Save and plot image with no axis or spines."""
import pandas as pd
import seaborn as sn
@ -271,8 +271,11 @@ def plot_labels(boxes, cls, names=(), save_dir=Path('')):
for s in ['top', 'right', 'left', 'bottom']:
ax[a].spines[s].set_visible(False)
plt.savefig(save_dir / 'labels.jpg', dpi=200)
fname = save_dir / 'labels.jpg'
plt.savefig(fname, dpi=200)
plt.close()
if on_plot:
on_plot(fname)
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
@ -301,7 +304,8 @@ def plot_images(images,
kpts=np.zeros((0, 51), dtype=np.float32),
paths=None,
fname='images.jpg',
names=None):
names=None,
on_plot=None):
# Plot image grid with labels
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
@ -419,10 +423,12 @@ def plot_images(images,
im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
annotator.fromarray(im)
annotator.im.save(fname) # save
if on_plot:
on_plot(fname)
@plt_settings()
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False):
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
"""Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')."""
import pandas as pd
save_dir = Path(file).parent if file else Path(dir)
@ -456,8 +462,11 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
except Exception as e:
LOGGER.warning(f'WARNING: Plotting error for {f}: {e}')
ax[1].legend()
fig.savefig(save_dir / 'results.png', dpi=200)
fname = save_dir / 'results.png'
fig.savefig(fname, dpi=200)
plt.close()
if on_plot:
on_plot(fname)
def output_to_target(output, max_det=300):

@ -71,7 +71,7 @@ class ClassificationTrainer(BaseTrainer):
return # don't return ckpt. Classification doesn't support resume
def build_dataset(self, img_path, mode='train', batch=None):
return ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=mode == 'train')
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
@ -126,7 +126,7 @@ class ClassificationTrainer(BaseTrainer):
def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv, classify=True) # save results.png
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
def final_eval(self):
"""Evaluate trained model and save validation results."""
@ -147,7 +147,8 @@ class ClassificationTrainer(BaseTrainer):
plot_images(images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].squeeze(-1),
fname=self.save_dir / f'train_batch{ni}.jpg')
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def train(cfg=DEFAULT_CFG, use_python=False):

@ -47,7 +47,10 @@ class ClassificationValidator(BaseValidator):
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir, names=self.names.values(), normalize=normalize)
self.confusion_matrix.plot(save_dir=self.save_dir,
names=self.names.values(),
normalize=normalize,
on_plot=self.on_plot)
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
@ -57,7 +60,7 @@ class ClassificationValidator(BaseValidator):
return self.metrics.results_dict
def build_dataset(self, img_path):
dataset = ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=False)
dataset = ClassificationDataset(root=img_path, args=self.args, augment=False)
return dataset
def get_dataloader(self, dataset_path, batch_size):
@ -76,7 +79,8 @@ class ClassificationValidator(BaseValidator):
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].squeeze(-1),
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names)
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
@ -84,7 +88,8 @@ class ClassificationValidator(BaseValidator):
batch_idx=torch.arange(len(batch['img'])),
cls=torch.argmax(preds, dim=1),
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names) # pred
names=self.names,
on_plot=self.on_plot) # pred
def val(cfg=DEFAULT_CFG, use_python=False):

@ -121,17 +121,18 @@ class DetectionTrainer(BaseTrainer):
cls=batch['cls'].squeeze(-1),
bboxes=batch['bboxes'],
paths=batch['im_file'],
fname=self.save_dir / f'train_batch{ni}.jpg')
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv) # save results.png
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
def plot_training_labels(self):
"""Create a labeled training plot of the YOLO model."""
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir)
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)
# Criterion class for computing training losses

@ -24,7 +24,7 @@ class DetectionValidator(BaseValidator):
self.args.task = 'detect'
self.is_coco = False
self.class_map = None
self.metrics = DetMetrics(save_dir=self.save_dir)
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
self.niou = self.iouv.numel()
@ -145,7 +145,10 @@ class DetectionValidator(BaseValidator):
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir, names=self.names.values(), normalize=normalize)
self.confusion_matrix.plot(save_dir=self.save_dir,
names=self.names.values(),
normalize=normalize,
on_plot=self.on_plot)
def _process_batch(self, detections, labels):
"""
@ -215,7 +218,8 @@ class DetectionValidator(BaseValidator):
batch['bboxes'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names)
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
@ -223,7 +227,8 @@ class DetectionValidator(BaseValidator):
*output_to_target(preds, max_det=15),
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names) # pred
names=self.names,
on_plot=self.on_plot) # pred
def save_one_txt(self, predn, save_conf, shape, file):
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""

@ -65,11 +65,12 @@ class PoseTrainer(v8.detect.DetectionTrainer):
bboxes,
kpts=kpts,
paths=paths,
fname=self.save_dir / f'train_batch{ni}.jpg')
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, pose=True) # save results.png
plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png
# Criterion class for computing training losses

@ -18,7 +18,7 @@ class PoseValidator(DetectionValidator):
"""Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'pose'
self.metrics = PoseMetrics(save_dir=self.save_dir)
self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
def preprocess(self, batch):
"""Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
@ -150,7 +150,8 @@ class PoseValidator(DetectionValidator):
kpts=batch['keypoints'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names)
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predictions for YOLO model."""
@ -160,7 +161,8 @@ class PoseValidator(DetectionValidator):
kpts=pred_kpts,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names) # pred
names=self.names,
on_plot=self.on_plot) # pred
def pred_to_json(self, predn, filename):
"""Converts YOLO predictions to COCO JSON format."""

@ -45,17 +45,18 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
def plot_training_samples(self, batch, ni):
"""Creates a plot of training sample images with labels and box coordinates."""
images = batch['img']
masks = batch['masks']
cls = batch['cls'].squeeze(-1)
bboxes = batch['bboxes']
paths = batch['im_file']
batch_idx = batch['batch_idx']
plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg')
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, segment=True) # save results.png
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
# Criterion class for computing training losses

@ -20,7 +20,7 @@ class SegmentationValidator(DetectionValidator):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'segment'
self.metrics = SegmentMetrics(save_dir=self.save_dir)
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
def preprocess(self, batch):
"""Preprocesses batch by converting masks to float and sending to device."""
@ -174,7 +174,8 @@ class SegmentationValidator(DetectionValidator):
batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names)
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes."""
@ -183,7 +184,8 @@ class SegmentationValidator(DetectionValidator):
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names) # pred
names=self.names,
on_plot=self.on_plot) # pred
self.plot_masks.clear()
def pred_to_json(self, predn, filename, pred_masks):
