[Docs]: Add customization tutorial and address feedback (#155)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent c985eaba0d
commit d387359f74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,10 +3,10 @@ If you want to train, validate or run inference on models and don't need to make
!!! tip "Syntax" !!! tip "Syntax"
```bash ```bash
yolo task=detect mode=train model=s.yaml epochs=1 ... yolo task=detect mode=train model=yolov8n.yaml epochs=1 ...
... ... ... ... ... ...
segment infer s-cls.pt segment predict yolov8n-seg.pt
classify val s-seg.pt classify val yolov8n-cls.pt
``` ```
The experiment arguments can be overridden directly by pass `arg=val` covered in the next section. You can run any supported task by setting `task` and `mode` in cli. The experiment arguments can be overridden directly by pass `arg=val` covered in the next section. You can run any supported task by setting `task` and `mode` in cli.
@ -18,13 +18,13 @@ The experiment arguments can be overridden directly by pass `arg=val` covered in
| Instance Segment | `segment` | <pre><code>yolo task=segment mode=train </code></pre> | | Instance Segment | `segment` | <pre><code>yolo task=segment mode=train </code></pre> |
| Classification| `classify` | <pre><code>yolo task=classify mode=train </code></pre> | | Classification| `classify` | <pre><code>yolo task=classify mode=train </code></pre> |
=== "Inference" === "Prediction"
| | `task` | snippet | | | `task` | snippet |
| ----------- | ------------- | ------------------------------------------------------------ | | ----------- | ------------- | ------------------------------------------------------------ |
| Detection | `detect` | <pre><code>yolo task=detect mode=infer </code></pre> | | Detection | `detect` | <pre><code>yolo task=detect mode=predict </code></pre> |
| Instance Segment | `segment` | <pre><code>yolo task=segment mode=infer </code></pre>| | Instance Segment | `segment` | <pre><code>yolo task=segment mode=predict </code></pre>|
| Classification| `classify` | <pre><code>yolo task=classify mode=infer </code></pre>| | Classification| `classify` | <pre><code>yolo task=classify mode=predict </code></pre>|
=== "Validation" === "Validation"

@ -46,7 +46,7 @@ include train, val, and predict.
| model | null | Set the model. Format can differ for task type. Supports `model_name`, `model.yaml` & `model.pt` | | model | null | Set the model. Format can differ for task type. Supports `model_name`, `model.yaml` & `model.pt` |
| data | null | Set the data. Format can differ for task type. Supports `data.yaml`, `data_folder`, `dataset_name` | | data | null | Set the data. Format can differ for task type. Supports `data.yaml`, `data_folder`, `dataset_name` |
### Training settings ### Training
Training settings for YOLO models refer to the various hyperparameters and configurations used to train the model on a Training settings for YOLO models refer to the various hyperparameters and configurations used to train the model on a
dataset. These settings can affect the model's performance, speed, and accuracy. Some common YOLO training settings dataset. These settings can affect the model's performance, speed, and accuracy. Some common YOLO training settings
@ -88,7 +88,7 @@ task.
| mask_ratio | 4 | **Segmentation**: Set mask downsampling | | mask_ratio | 4 | **Segmentation**: Set mask downsampling |
| dropout | `False` | **Classification**: Use dropout while training | | dropout | `False` | **Classification**: Use dropout while training |
### Prediction Settings ### Prediction
Prediction settings for YOLO models refer to the various hyperparameters and configurations used to make predictions Prediction settings for YOLO models refer to the various hyperparameters and configurations used to make predictions
with the model on new data. These settings can affect the model's performance, speed, and accuracy. Some common YOLO with the model on new data. These settings can affect the model's performance, speed, and accuracy. Some common YOLO
@ -114,7 +114,7 @@ given task.
| agnostic_nms | `False` | Class-agnostic NMS | | agnostic_nms | `False` | Class-agnostic NMS |
| retina_masks | `False` | **Segmentation:** High resolution masks | | retina_masks | `False` | **Segmentation:** High resolution masks |
### Validation settings ### Validation
Validation settings for YOLO models refer to the various hyperparameters and configurations used to Validation settings for YOLO models refer to the various hyperparameters and configurations used to
evaluate the model's performance on a validation dataset. These settings can affect the model's performance, speed, and evaluate the model's performance on a validation dataset. These settings can affect the model's performance, speed, and
@ -147,7 +147,7 @@ the specific task the model is being used for and the requirements or constraint
It is important to carefully consider and configure these settings to ensure that the exported model is optimized for It is important to carefully consider and configure these settings to ensure that the exported model is optimized for
the intended use case and can be used effectively in the target environment. the intended use case and can be used effectively in the target environment.
### Augmentation settings ### Augmentation
Augmentation settings for YOLO models refer to the various transformations and modifications Augmentation settings for YOLO models refer to the various transformations and modifications
applied to the training data to increase the diversity and size of the dataset. These settings can affect the model's applied to the training data to increase the diversity and size of the dataset. These settings can affect the model's

@ -0,0 +1,66 @@
Both the Ultralytics YOLO command-line and python interfaces are simply a high-level abstraction on the base engine executors. Let's take a look at the Trainer engine.
## BaseTrainer
BaseTrainer contains the generic boilerplate training routine. It can be customized for any task based over overidding the required functions or operations as long the as correct formats are followed. For example you can support your own custom model and dataloder by just overriding these functions:
* `get_model(cfg, weights)` - The function that builds a the model to be trained
* `get_dataloder()` - The function that builds the dataloder
More details and source code can be found in [`BaseTrainer` Reference](../reference/base_trainer.md)
## DetectionTrainer
Here's how you can use the YOLOv8 `DetectionTrainer` and customize it.
```python
from Ultrlaytics.yolo.v8 import DetectionTrainer
trainer = DetectionTrainer(overrides={...})
trainer.train()
trained_model = trainer.best # get best model
```
### Customizing the DetectionTrainer
Let's customize the trainer **to train a custom detection model** that is not supported directly. You can do this by simply overloading the existing the `get_model` functionality:
```python
from Ultrlaytics.yolo.v8 import DetectionTrainer
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
...
trainer = CustomTrainer(overrides={...})
trainer.train()
```
You now realize that you need to customize the trainer further to:
* Customize the `loss function`.
* Add `callback` that uploads model to your google drive after every 10 `epochs`
Here's how you can do it:
```python
from Ultrlaytics.yolo.v8 import DetectionTrainer
class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights):
...
def criterion(self, preds, batch):
# get ground truth
imgs = batch["imgs"]
bboxes = batch["bboxes"]
...
return loss, loss_items # see Reference-> Trainer for details on the expected format
# callback to upload model weights
def log_model(trainer):
last_weight_path = trainer.last
...
trainer = CustomTrainer(overrides={...})
trainer.add_callback("on_train_epoch_end", log_model) # Adds to existing callback
trainer.train()
```
To know more about Callback triggering events and entry point, checkout our Callbacks guide # TODO
## Other engine components
There are other componenets that can be customized similarly like `Validators` and `Predictiors`
To know more about their implementation details, go to Reference

@ -49,7 +49,7 @@ For more information about the history and development of YOLO, you can refer to
conference on computer vision and pattern recognition (pp. 779-788). conference on computer vision and pattern recognition (pp. 779-788).
- Redmon, J., & Farhadi, A. (2016). YOLO9000: Better, faster, stronger. In Proceedings - Redmon, J., & Farhadi, A. (2016). YOLO9000: Better, faster, stronger. In Proceedings
### YOLOv8 by Ultralytics ### Ultralytics YOLOv8
YOLOv8 is the latest version of the YOLO object detection and image segmentation model developed by YOLOv8 is the latest version of the YOLO object detection and image segmentation model developed by
Ultralytics. YOLOv8 is a cutting-edge, state-of-the-art (SOTA) model that builds upon the success of previous YOLO Ultralytics. YOLOv8 is a cutting-edge, state-of-the-art (SOTA) model that builds upon the success of previous YOLO

@ -56,42 +56,36 @@ This is the simplest way of simply using yolo models in a python environment. It
More functionality coming soon More functionality coming soon
To know more about using `YOLO` models, refer Model class refernce To know more about using `YOLO` models, refer Model class Reference
[Model reference](reference/model.md){ .md-button .md-button--primary} [Model reference](reference/model.md){ .md-button .md-button--primary}
--- ---
### Customizing Tasks with Trainers ### Using Trainers
`YOLO` model class is a high-level wrapper on the Trainer classes. Each YOLO task has its own trainer that inherits from `BaseTrainer`. `YOLO` model class is a high-level wrapper on the Trainer classes. Each YOLO task has its own trainer that inherits from `BaseTrainer`.
You can easily cusotmize Trainers to support custom tasks or explore R&D ideas. !!! tip "Detection Trainer Example"
!!! tip "Trainer Examples"
=== "DetectionTrainer"
```python ```python
from ultralytics import yolo from ultralytics.yolo import v8 import DetectionTrainer, DetectionValidator, DetectionPredictor
trainer = yolo.DetectionTrainer(data=..., epochs=1) # override default configs # trainer
trainer = yolo.DetectionTrainer(data=..., epochs=1, device="1,2,3,4") # DDP trainer = DetectionTrainer(overrides={})
trainer.train() trainer.train()
``` trained_model = trainer.best
=== "SegmentationTrainer" # Validator
```python val = DetectionValidator(args=...)
from ultralytics import yolo val(model=trained_model)
trainer = yolo.SegmentationTrainer(data=..., epochs=1) # override default configs # predictor
trainer = yolo.SegmentationTrainer(data=..., epochs=1, device="0,1,2,3") # DDP pred = DetectionPredictor(overrides={})
trainer.train() pred(source=SOURCE, model=trained_model)
```
=== "ClassificationTrainer"
```python
from ultralytics import yolo
trainer = yolo.ClassificationTrainer(data=..., epochs=1) # override default configs # resume from last weight
trainer = yolo.ClassificationTrainer(data=..., epochs=1, device="0,1,2,3") # DDP overrides["resume"] = trainer.last
trainer.train() trainer = detect.DetectionTrainer(overrides=overrides)
```
Learn more about Customizing `Trainers`, `Validators` and `Predictors` to suit your project needs in the Customization Section. More details about the base engine classes is available in the reference section. ```
You can easily customize Trainers to support custom tasks or explore R&D ideas.
Learn more about Customizing `Trainers`, `Validators` and `Predictors` to suit your project needs in the Customization Section.
[Customization tutorials](#){ .md-button .md-button--primary} [Customization tutorials](engine.md){ .md-button .md-button--primary}

@ -81,14 +81,7 @@ nav:
- CLI: cli.md - CLI: cli.md
- Python Interface: sdk.md - Python Interface: sdk.md
- Configuration: config.md - Configuration: config.md
- Tasks: - Customization Guide: engine.md
- Detection: tasks/detection.md
- Segmentation: tasks/segmentation.md
- Classification: tasks/classification.md
- Advanced Tutorials:
- Customize Trainer: customize/train.md
- Customize Validator: customize/val.md
- Customize Predictor: customize/predict.md
- Reference: - Reference:
- Python Model interface: reference/model.md - Python Model interface: reference/model.md
- Engine: - Engine:

@ -30,7 +30,7 @@ class YOLO:
def __init__(self, model='yolov8n.yaml', type="v8") -> None: def __init__(self, model='yolov8n.yaml', type="v8") -> None:
""" """
Initializes the YOLO object. > Initializes the YOLO object.
Args: Args:
model (str, Path): model to load or create model (str, Path): model to load or create
@ -57,7 +57,7 @@ class YOLO:
def _new(self, cfg: str, verbose=True): def _new(self, cfg: str, verbose=True):
""" """
Initializes a new model and infers the task type from the model definitions. > Initializes a new model and infers the task type from the model definitions.
Args: Args:
cfg (str): model configuration file cfg (str): model configuration file
@ -73,7 +73,7 @@ class YOLO:
def _load(self, weights: str): def _load(self, weights: str):
""" """
Initializes a new model and infers the task type from the model head > Initializes a new model and infers the task type from the model head.
Args: Args:
weights (str): model checkpoint to be loaded weights (str): model checkpoint to be loaded
@ -88,7 +88,7 @@ class YOLO:
def reset(self): def reset(self):
""" """
Resets the model modules . > Resets the model modules.
""" """
for m in self.model.modules(): for m in self.model.modules():
if hasattr(m, 'reset_parameters'): if hasattr(m, 'reset_parameters'):
@ -98,7 +98,7 @@ class YOLO:
def info(self, verbose=False): def info(self, verbose=False):
""" """
Logs model info > Logs model info.
Args: Args:
verbose (bool): Controls verbosity. verbose (bool): Controls verbosity.
@ -129,7 +129,7 @@ class YOLO:
@smart_inference_mode() @smart_inference_mode()
def val(self, data=None, **kwargs): def val(self, data=None, **kwargs):
""" """
Validate a model on a given dataset > Validate a model on a given dataset .
Args: Args:
data (str): The dataset to validate on. Accepts all formats accepted by yolo data (str): The dataset to validate on. Accepts all formats accepted by yolo
@ -148,7 +148,7 @@ class YOLO:
@smart_inference_mode() @smart_inference_mode()
def export(self, **kwargs): def export(self, **kwargs):
""" """
Export model. > Export model.
Args: Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
@ -164,7 +164,7 @@ class YOLO:
def train(self, **kwargs): def train(self, **kwargs):
""" """
Trains the model on given dataset. > Trains the model on a given dataset.
Args: Args:
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section. **kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
@ -189,6 +189,12 @@ class YOLO:
self.trainer.train() self.trainer.train()
def to(self, device): def to(self, device):
"""
> Sends the model to the given device.
Args:
device (str): device
"""
self.model.to(device) self.model.to(device)
def _guess_ops_from_task(self, task): def _guess_ops_from_task(self, task):

@ -39,7 +39,7 @@ class BaseTrainer:
""" """
BaseTrainer BaseTrainer
A base class for creating trainers. > A base class for creating trainers.
Attributes: Attributes:
args (OmegaConf): Configuration for the trainer. args (OmegaConf): Configuration for the trainer.
@ -74,7 +74,7 @@ class BaseTrainer:
def __init__(self, config=DEFAULT_CONFIG, overrides=None): def __init__(self, config=DEFAULT_CONFIG, overrides=None):
""" """
Initializes the BaseTrainer class. > Initializes the BaseTrainer class.
Args: Args:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
@ -148,13 +148,13 @@ class BaseTrainer:
def add_callback(self, event: str, callback): def add_callback(self, event: str, callback):
""" """
Appends the given callback. TODO: unused, consider removing > Appends the given callback.
""" """
self.callbacks[event].append(callback) self.callbacks[event].append(callback)
def set_callback(self, event: str, callback): def set_callback(self, event: str, callback):
""" """
Overrides the existing callbacks with the given callback. TODO: unused, consider removing > Overrides the existing callbacks with the given callback.
""" """
self.callbacks[event] = [callback] self.callbacks[event] = [callback]
@ -185,7 +185,7 @@ class BaseTrainer:
def _setup_train(self, rank, world_size): def _setup_train(self, rank, world_size):
""" """
Builds dataloaders and optimizer on correct rank process > Builds dataloaders and optimizer on correct rank process.
""" """
# model # model
self.run_callbacks("on_pretrain_routine_start") self.run_callbacks("on_pretrain_routine_start")
@ -373,13 +373,13 @@ class BaseTrainer:
def get_dataset(self, data): def get_dataset(self, data):
""" """
Get train, val path from data dict if it exists. Returns None if data format is not recognized > Get train, val path from data dict if it exists. Returns None if data format is not recognized.
""" """
return data["train"], data.get("val") or data.get("test") return data["train"], data.get("val") or data.get("test")
def setup_model(self): def setup_model(self):
""" """
load/create/download model for any task > load/create/download model for any task.
""" """
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return return
@ -405,15 +405,13 @@ class BaseTrainer:
def preprocess_batch(self, batch): def preprocess_batch(self, batch):
""" """
Allows custom preprocessing model inputs and ground truths depending on task type > Allows custom preprocessing model inputs and ground truths depending on task type.
""" """
return batch return batch
def validate(self): def validate(self):
""" """
Runs validation on test set using self.validator. > Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
# TODO: discuss validator class. Enforce that a validator metrics dict should contain
"fitness" metric.
""" """
metrics = self.validator(self) metrics = self.validator(self)
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
@ -423,9 +421,11 @@ class BaseTrainer:
def log(self, text, rank=-1): def log(self, text, rank=-1):
""" """
Logs the given text to given ranks process if provided, otherwise logs to all ranks > Logs the given text to given ranks process if provided, otherwise logs to all ranks.
:param text: text to log
:param rank: List[Int] Args"
text (str): text to log
rank (List[Int]): process rank
""" """
if rank in {-1, 0}: if rank in {-1, 0}:
@ -439,13 +439,13 @@ class BaseTrainer:
def get_dataloader(self, dataset_path, batch_size=16, rank=0): def get_dataloader(self, dataset_path, batch_size=16, rank=0):
""" """
Returns dataloader derived from torch.data.Dataloader > Returns dataloader derived from torch.data.Dataloader.
""" """
raise NotImplementedError("get_dataloader function not implemented in trainer") raise NotImplementedError("get_dataloader function not implemented in trainer")
def criterion(self, preds, batch): def criterion(self, preds, batch):
""" """
Returns loss and individual loss items as Tensor > Returns loss and individual loss items as Tensor.
""" """
raise NotImplementedError("criterion function not implemented in trainer") raise NotImplementedError("criterion function not implemented in trainer")
@ -531,7 +531,7 @@ class BaseTrainer:
@staticmethod @staticmethod
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5): def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
""" """
Builds an optimizer with the specified parameters and parameter groups. > Builds an optimizer with the specified parameters and parameter groups.
Args: Args:
model (nn.Module): model to optimize model (nn.Module): model to optimize

Loading…
Cancel
Save