General refactoring and improvements (#373)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -32,7 +32,7 @@ class YOLO:
|
||||
|
||||
def __init__(self, model='yolov8n.yaml', type="v8") -> None:
|
||||
"""
|
||||
> Initializes the YOLO object.
|
||||
Initializes the YOLO object.
|
||||
|
||||
Args:
|
||||
model (str, Path): model to load or create
|
||||
@ -59,7 +59,7 @@ class YOLO:
|
||||
|
||||
def _new(self, cfg: str, verbose=True):
|
||||
"""
|
||||
> Initializes a new model and infers the task type from the model definitions.
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
|
||||
Args:
|
||||
cfg (str): model configuration file
|
||||
@ -75,7 +75,7 @@ class YOLO:
|
||||
|
||||
def _load(self, weights: str):
|
||||
"""
|
||||
> Initializes a new model and infers the task type from the model head.
|
||||
Initializes a new model and infers the task type from the model head.
|
||||
|
||||
Args:
|
||||
weights (str): model checkpoint to be loaded
|
||||
@ -90,7 +90,7 @@ class YOLO:
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
> Resets the model modules.
|
||||
Resets the model modules.
|
||||
"""
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, 'reset_parameters'):
|
||||
@ -100,7 +100,7 @@ class YOLO:
|
||||
|
||||
def info(self, verbose=False):
|
||||
"""
|
||||
> Logs model info.
|
||||
Logs model info.
|
||||
|
||||
Args:
|
||||
verbose (bool): Controls verbosity.
|
||||
@ -133,7 +133,7 @@ class YOLO:
|
||||
@smart_inference_mode()
|
||||
def val(self, data=None, **kwargs):
|
||||
"""
|
||||
> Validate a model on a given dataset .
|
||||
Validate a model on a given dataset.
|
||||
|
||||
Args:
|
||||
data (str): The dataset to validate on. Accepts all formats accepted by yolo
|
||||
@ -152,7 +152,7 @@ class YOLO:
|
||||
@smart_inference_mode()
|
||||
def export(self, **kwargs):
|
||||
"""
|
||||
> Export model.
|
||||
Export model.
|
||||
|
||||
Args:
|
||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
||||
@ -168,7 +168,7 @@ class YOLO:
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
> Trains the model on a given dataset.
|
||||
Trains the model on a given dataset.
|
||||
|
||||
Args:
|
||||
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
|
||||
@ -197,7 +197,7 @@ class YOLO:
|
||||
|
||||
def to(self, device):
|
||||
"""
|
||||
> Sends the model to the given device.
|
||||
Sends the model to the given device.
|
||||
|
||||
Args:
|
||||
device (str): device
|
||||
|
@ -89,7 +89,7 @@ class BasePredictor:
|
||||
self.vid_path, self.vid_writer = None, None
|
||||
self.annotator = None
|
||||
self.data_path = None
|
||||
self.output = dict()
|
||||
self.output = {}
|
||||
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
@ -216,7 +216,7 @@ class BasePredictor:
|
||||
self.run_callbacks("on_predict_end")
|
||||
|
||||
def predict_cli(self, source=None, model=None, return_outputs=False):
|
||||
# as __call__ is a genertor now so have to treat it like a genertor
|
||||
# as __call__ is a generator now so have to treat it like a generator
|
||||
for _ in (self.__call__(source, model, return_outputs)):
|
||||
pass
|
||||
|
||||
|
@ -40,7 +40,7 @@ class BaseTrainer:
|
||||
"""
|
||||
BaseTrainer
|
||||
|
||||
> A base class for creating trainers.
|
||||
A base class for creating trainers.
|
||||
|
||||
Attributes:
|
||||
args (OmegaConf): Configuration for the trainer.
|
||||
@ -75,7 +75,7 @@ class BaseTrainer:
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
"""
|
||||
> Initializes the BaseTrainer class.
|
||||
Initializes the BaseTrainer class.
|
||||
|
||||
Args:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
@ -149,13 +149,13 @@ class BaseTrainer:
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
> Appends the given callback.
|
||||
Appends the given callback.
|
||||
"""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, event: str, callback):
|
||||
"""
|
||||
> Overrides the existing callbacks with the given callback.
|
||||
Overrides the existing callbacks with the given callback.
|
||||
"""
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
@ -194,7 +194,7 @@ class BaseTrainer:
|
||||
|
||||
def _setup_train(self, rank, world_size):
|
||||
"""
|
||||
> Builds dataloaders and optimizer on correct rank process.
|
||||
Builds dataloaders and optimizer on correct rank process.
|
||||
"""
|
||||
# model
|
||||
self.run_callbacks("on_pretrain_routine_start")
|
||||
@ -383,13 +383,13 @@ class BaseTrainer:
|
||||
|
||||
def get_dataset(self, data):
|
||||
"""
|
||||
> Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
||||
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
||||
"""
|
||||
return data["train"], data.get("val") or data.get("test")
|
||||
|
||||
def setup_model(self):
|
||||
"""
|
||||
> load/create/download model for any task.
|
||||
load/create/download model for any task.
|
||||
"""
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
@ -415,13 +415,13 @@ class BaseTrainer:
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""
|
||||
> Allows custom preprocessing model inputs and ground truths depending on task type.
|
||||
Allows custom preprocessing model inputs and ground truths depending on task type.
|
||||
"""
|
||||
return batch
|
||||
|
||||
def validate(self):
|
||||
"""
|
||||
> Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
||||
Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
||||
"""
|
||||
metrics = self.validator(self)
|
||||
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||
@ -431,7 +431,7 @@ class BaseTrainer:
|
||||
|
||||
def log(self, text, rank=-1):
|
||||
"""
|
||||
> Logs the given text to given ranks process if provided, otherwise logs to all ranks.
|
||||
Logs the given text to given ranks process if provided, otherwise logs to all ranks.
|
||||
|
||||
Args:
|
||||
text (str): text to log
|
||||
@ -449,13 +449,13 @@ class BaseTrainer:
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0):
|
||||
"""
|
||||
> Returns dataloader derived from torch.data.Dataloader.
|
||||
Returns dataloader derived from torch.data.Dataloader.
|
||||
"""
|
||||
raise NotImplementedError("get_dataloader function not implemented in trainer")
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
"""
|
||||
> Returns loss and individual loss items as Tensor.
|
||||
Returns loss and individual loss items as Tensor.
|
||||
"""
|
||||
raise NotImplementedError("criterion function not implemented in trainer")
|
||||
|
||||
@ -543,7 +543,7 @@ class BaseTrainer:
|
||||
@staticmethod
|
||||
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
||||
"""
|
||||
> Builds an optimizer with the specified parameters and parameter groups.
|
||||
Builds an optimizer with the specified parameters and parameter groups.
|
||||
|
||||
Args:
|
||||
model (nn.Module): model to optimize
|
||||
|
Reference in New Issue
Block a user