[Docs]: Add customization tutorial and address feedback (#155)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -30,7 +30,7 @@ class YOLO:
|
||||
|
||||
def __init__(self, model='yolov8n.yaml', type="v8") -> None:
|
||||
"""
|
||||
Initializes the YOLO object.
|
||||
> Initializes the YOLO object.
|
||||
|
||||
Args:
|
||||
model (str, Path): model to load or create
|
||||
@ -57,7 +57,7 @@ class YOLO:
|
||||
|
||||
def _new(self, cfg: str, verbose=True):
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
> Initializes a new model and infers the task type from the model definitions.
|
||||
|
||||
Args:
|
||||
cfg (str): model configuration file
|
||||
@ -73,7 +73,7 @@ class YOLO:
|
||||
|
||||
def _load(self, weights: str):
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model head
|
||||
> Initializes a new model and infers the task type from the model head.
|
||||
|
||||
Args:
|
||||
weights (str): model checkpoint to be loaded
|
||||
@ -88,7 +88,7 @@ class YOLO:
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the model modules .
|
||||
> Resets the model modules.
|
||||
"""
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, 'reset_parameters'):
|
||||
@ -98,7 +98,7 @@ class YOLO:
|
||||
|
||||
def info(self, verbose=False):
|
||||
"""
|
||||
Logs model info
|
||||
> Logs model info.
|
||||
|
||||
Args:
|
||||
verbose (bool): Controls verbosity.
|
||||
@ -129,7 +129,7 @@ class YOLO:
|
||||
@smart_inference_mode()
|
||||
def val(self, data=None, **kwargs):
|
||||
"""
|
||||
Validate a model on a given dataset
|
||||
> Validate a model on a given dataset .
|
||||
|
||||
Args:
|
||||
data (str): The dataset to validate on. Accepts all formats accepted by yolo
|
||||
@ -148,7 +148,7 @@ class YOLO:
|
||||
@smart_inference_mode()
|
||||
def export(self, **kwargs):
|
||||
"""
|
||||
Export model.
|
||||
> Export model.
|
||||
|
||||
Args:
|
||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
||||
@ -164,7 +164,7 @@ class YOLO:
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
Trains the model on given dataset.
|
||||
> Trains the model on a given dataset.
|
||||
|
||||
Args:
|
||||
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
|
||||
@ -189,6 +189,12 @@ class YOLO:
|
||||
self.trainer.train()
|
||||
|
||||
def to(self, device):
|
||||
"""
|
||||
> Sends the model to the given device.
|
||||
|
||||
Args:
|
||||
device (str): device
|
||||
"""
|
||||
self.model.to(device)
|
||||
|
||||
def _guess_ops_from_task(self, task):
|
||||
|
@ -39,7 +39,7 @@ class BaseTrainer:
|
||||
"""
|
||||
BaseTrainer
|
||||
|
||||
A base class for creating trainers.
|
||||
> A base class for creating trainers.
|
||||
|
||||
Attributes:
|
||||
args (OmegaConf): Configuration for the trainer.
|
||||
@ -74,7 +74,7 @@ class BaseTrainer:
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
"""
|
||||
Initializes the BaseTrainer class.
|
||||
> Initializes the BaseTrainer class.
|
||||
|
||||
Args:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
@ -148,13 +148,13 @@ class BaseTrainer:
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
Appends the given callback. TODO: unused, consider removing
|
||||
> Appends the given callback.
|
||||
"""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, event: str, callback):
|
||||
"""
|
||||
Overrides the existing callbacks with the given callback. TODO: unused, consider removing
|
||||
> Overrides the existing callbacks with the given callback.
|
||||
"""
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
@ -185,7 +185,7 @@ class BaseTrainer:
|
||||
|
||||
def _setup_train(self, rank, world_size):
|
||||
"""
|
||||
Builds dataloaders and optimizer on correct rank process
|
||||
> Builds dataloaders and optimizer on correct rank process.
|
||||
"""
|
||||
# model
|
||||
self.run_callbacks("on_pretrain_routine_start")
|
||||
@ -373,13 +373,13 @@ class BaseTrainer:
|
||||
|
||||
def get_dataset(self, data):
|
||||
"""
|
||||
Get train, val path from data dict if it exists. Returns None if data format is not recognized
|
||||
> Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
||||
"""
|
||||
return data["train"], data.get("val") or data.get("test")
|
||||
|
||||
def setup_model(self):
|
||||
"""
|
||||
load/create/download model for any task
|
||||
> load/create/download model for any task.
|
||||
"""
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
@ -405,15 +405,13 @@ class BaseTrainer:
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""
|
||||
Allows custom preprocessing model inputs and ground truths depending on task type
|
||||
> Allows custom preprocessing model inputs and ground truths depending on task type.
|
||||
"""
|
||||
return batch
|
||||
|
||||
def validate(self):
|
||||
"""
|
||||
Runs validation on test set using self.validator.
|
||||
# TODO: discuss validator class. Enforce that a validator metrics dict should contain
|
||||
"fitness" metric.
|
||||
> Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
||||
"""
|
||||
metrics = self.validator(self)
|
||||
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||
@ -423,9 +421,11 @@ class BaseTrainer:
|
||||
|
||||
def log(self, text, rank=-1):
|
||||
"""
|
||||
Logs the given text to given ranks process if provided, otherwise logs to all ranks
|
||||
:param text: text to log
|
||||
:param rank: List[Int]
|
||||
> Logs the given text to given ranks process if provided, otherwise logs to all ranks.
|
||||
|
||||
Args"
|
||||
text (str): text to log
|
||||
rank (List[Int]): process rank
|
||||
|
||||
"""
|
||||
if rank in {-1, 0}:
|
||||
@ -439,13 +439,13 @@ class BaseTrainer:
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0):
|
||||
"""
|
||||
Returns dataloader derived from torch.data.Dataloader
|
||||
> Returns dataloader derived from torch.data.Dataloader.
|
||||
"""
|
||||
raise NotImplementedError("get_dataloader function not implemented in trainer")
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
"""
|
||||
Returns loss and individual loss items as Tensor
|
||||
> Returns loss and individual loss items as Tensor.
|
||||
"""
|
||||
raise NotImplementedError("criterion function not implemented in trainer")
|
||||
|
||||
@ -531,7 +531,7 @@ class BaseTrainer:
|
||||
@staticmethod
|
||||
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
||||
"""
|
||||
Builds an optimizer with the specified parameters and parameter groups.
|
||||
> Builds an optimizer with the specified parameters and parameter groups.
|
||||
|
||||
Args:
|
||||
model (nn.Module): model to optimize
|
||||
|
Reference in New Issue
Block a user