[Docs]: Add customization tutorial and address feedback (#155)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Authored by Ayush Chaurasia on 2023-01-08 18:31:22 +05:30, committed by GitHub
parent c985eaba0d
commit d387359f74
8 changed files with 133 additions and 74 deletions


@@ -30,7 +30,7 @@ class YOLO:
def __init__(self, model='yolov8n.yaml', type="v8") -> None:
"""
Initializes the YOLO object.
> Initializes the YOLO object.
Args:
model (str, Path): model to load or create
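For reference, a minimal constructor sketch based on the signature above; the package-level `YOLO` import and the `yolov8n.*` filenames are assumptions for illustration:

```python
from ultralytics import YOLO  # assumed package-level export of the YOLO class shown above

# Build a new model from a config file (routes through _new; task inferred from the cfg)
model = YOLO('yolov8n.yaml')

# Or load an existing checkpoint (routes through _load; task inferred from the model head)
model = YOLO('yolov8n.pt')
```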
@@ -57,7 +57,7 @@ class YOLO:
def _new(self, cfg: str, verbose=True):
"""
Initializes a new model and infers the task type from the model definitions.
> Initializes a new model and infers the task type from the model definitions.
Args:
cfg (str): model configuration file
@@ -73,7 +73,7 @@ class YOLO:
def _load(self, weights: str):
"""
Initializes a new model and infers the task type from the model head
> Initializes a new model and infers the task type from the model head.
Args:
weights (str): model checkpoint to be loaded
@@ -88,7 +88,7 @@ class YOLO:
def reset(self):
"""
Resets the model modules .
> Resets the model modules.
"""
for m in self.model.modules():
if hasattr(m, 'reset_parameters'):
@@ -98,7 +98,7 @@ class YOLO:
def info(self, verbose=False):
"""
Logs model info
> Logs model info.
Args:
verbose (bool): Controls verbosity.
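A short usage sketch of the two helpers above (assumed package-level import; checkpoint name is illustrative):

```python
from ultralytics import YOLO  # assumed package-level export

model = YOLO('yolov8n.pt')
model.info(verbose=True)  # logs the model summary
model.reset()             # re-initializes every module that defines reset_parameters()
```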
@@ -129,7 +129,7 @@ class YOLO:
@smart_inference_mode()
def val(self, data=None, **kwargs):
"""
Validate a model on a given dataset
> Validate a model on a given dataset.
Args:
data (str): The dataset to validate on. Accepts all formats accepted by yolo
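A hedged sketch of a validation call; `coco128.yaml` is an illustrative dataset file and the import path is assumed:

```python
from ultralytics import YOLO  # assumed package-level export

model = YOLO('yolov8n.pt')
metrics = model.val(data='coco128.yaml')  # validates on the dataset described by the data file
```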
@@ -148,7 +148,7 @@ class YOLO:
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
> Export model.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
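A hedged export sketch; `format='onnx'` is assumed to be among the configuration keys the docstring refers to:

```python
from ultralytics import YOLO  # assumed package-level export

model = YOLO('yolov8n.pt')
model.export(format='onnx')  # 'format' is an assumed configuration key; see the docs section above
```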
@@ -164,7 +164,7 @@ class YOLO:
def train(self, **kwargs):
"""
Trains the model on given dataset.
> Trains the model on a given dataset.
Args:
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
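A hedged training sketch; the keyword arguments shown are assumed to match the 'config' section referenced above:

```python
from ultralytics import YOLO  # assumed package-level export

model = YOLO('yolov8n.yaml')
model.train(data='coco128.yaml', epochs=3, imgsz=640)  # kwargs mirror the 'config' section
```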
@@ -189,6 +189,12 @@ class YOLO:
self.trainer.train()
def to(self, device):
"""
> Sends the model to the given device.
Args:
device (str): device
"""
self.model.to(device)
def _guess_ops_from_task(self, task):
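A small sketch of the new `to()` helper (assumed import; device string chosen at runtime):

```python
import torch
from ultralytics import YOLO  # assumed package-level export

model = YOLO('yolov8n.pt')
model.to('cuda:0' if torch.cuda.is_available() else 'cpu')  # moves the wrapped nn.Module
```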


@@ -39,7 +39,7 @@ class BaseTrainer:
"""
BaseTrainer
A base class for creating trainers.
> A base class for creating trainers.
Attributes:
args (OmegaConf): Configuration for the trainer.
@@ -74,7 +74,7 @@ class BaseTrainer:
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
"""
Initializes the BaseTrainer class.
> Initializes the BaseTrainer class.
Args:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
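BaseTrainer is abstract, so a sketch of the `config`/`overrides` pattern is shown through an assumed concrete subclass and module path:

```python
# Module path and subclass name are assumptions based on the repository layout at this commit;
# BaseTrainer itself is abstract, so a task-specific subclass is used.
from ultralytics.yolo.v8.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={'data': 'coco128.yaml', 'epochs': 3})
trainer.train()
```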
@@ -148,13 +148,13 @@ class BaseTrainer:
def add_callback(self, event: str, callback):
"""
Appends the given callback. TODO: unused, consider removing
> Appends the given callback.
"""
self.callbacks[event].append(callback)
def set_callback(self, event: str, callback):
"""
Overrides the existing callbacks with the given callback. TODO: unused, consider removing
> Overrides the existing callbacks with the given callback.
"""
self.callbacks[event] = [callback]
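A hedged sketch of registering a callback; the subclass, module path, and the assumption that callbacks receive the trainer instance are illustrative, and the event name comes from `_setup_train()` below:

```python
from ultralytics.yolo.v8.detect import DetectionTrainer  # assumed concrete subclass and path

def log_start(trainer):
    # Callbacks are assumed to be invoked with the calling trainer instance
    print(f'Starting pretrain routine for {type(trainer).__name__}')

trainer = DetectionTrainer(overrides={'data': 'coco128.yaml', 'epochs': 1})
# 'on_pretrain_routine_start' is the event fired at the top of _setup_train()
trainer.add_callback('on_pretrain_routine_start', log_start)
trainer.train()
```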
@@ -185,7 +185,7 @@ class BaseTrainer:
def _setup_train(self, rank, world_size):
"""
Builds dataloaders and optimizer on correct rank process
> Builds dataloaders and optimizer on correct rank process.
"""
# model
self.run_callbacks("on_pretrain_routine_start")
@@ -373,13 +373,13 @@ class BaseTrainer:
def get_dataset(self, data):
"""
Get train, val path from data dict if it exists. Returns None if data format is not recognized
> Get train, val path from data dict if it exists. Returns None if data format is not recognized.
"""
return data["train"], data.get("val") or data.get("test")
def setup_model(self):
"""
load/create/download model for any task
> load/create/download model for any task.
"""
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
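To make the `get_dataset()` contract above concrete, a stand-alone sketch with hypothetical paths:

```python
# Hypothetical dataset dict with the keys get_dataset() looks up
data = {'train': 'coco128/images/train', 'val': 'coco128/images/val'}

# Mirrors the body of get_dataset(): 'val' falls back to a 'test' split when absent
train_path, val_path = data['train'], data.get('val') or data.get('test')
```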
@@ -405,15 +405,13 @@ class BaseTrainer:
def preprocess_batch(self, batch):
"""
Allows custom preprocessing model inputs and ground truths depending on task type
> Allows custom preprocessing of model inputs and ground truths depending on the task type.
"""
return batch
def validate(self):
"""
Runs validation on test set using self.validator.
# TODO: discuss validator class. Enforce that a validator metrics dict should contain
"fitness" metric.
> Runs validation on the test set using self.validator. The returned metrics dict is expected to contain the "fitness" key.
"""
metrics = self.validator(self)
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
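A hedged sketch of overriding `preprocess_batch()` in a custom trainer; the subclass path and the batch layout (a dict carrying an 'img' tensor) are assumptions:

```python
from ultralytics.yolo.v8.detect import DetectionTrainer  # assumed concrete subclass and path

class CustomTrainer(DetectionTrainer):
    def preprocess_batch(self, batch):
        # Assumed batch layout: a dict holding an 'img' tensor; scale pixels to [0, 1]
        batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
        return batch
```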
@@ -423,9 +421,11 @@ class BaseTrainer:
def log(self, text, rank=-1):
"""
Logs the given text to given ranks process if provided, otherwise logs to all ranks
:param text: text to log
:param rank: List[Int]
> Logs the given text to the given rank's process if provided, otherwise logs to all ranks.
Args:
text (str): text to log
rank (List[Int]): process rank
"""
if rank in {-1, 0}:
@@ -439,13 +439,13 @@ class BaseTrainer:
def get_dataloader(self, dataset_path, batch_size=16, rank=0):
"""
Returns dataloader derived from torch.data.Dataloader
> Returns a dataloader derived from torch.utils.data.DataLoader.
"""
raise NotImplementedError("get_dataloader function not implemented in trainer")
def criterion(self, preds, batch):
"""
Returns loss and individual loss items as Tensor
> Returns loss and individual loss items as Tensor.
"""
raise NotImplementedError("criterion function not implemented in trainer")
@@ -531,7 +531,7 @@ class BaseTrainer:
@staticmethod
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
"""
Builds an optimizer with the specified parameters and parameter groups.
> Builds an optimizer with the specified parameters and parameter groups.
Args:
model (nn.Module): model to optimize
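The static helper can be exercised directly on any torch module; a sketch using the signature above (module path assumed):

```python
import torch.nn as nn
from ultralytics.yolo.engine.trainer import BaseTrainer  # assumed module path

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.SiLU())
optimizer = BaseTrainer.build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5)
```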