General refactoring and improvements (#373)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-15 14:44:25 +01:00
committed by GitHub
parent ac628c0d3e
commit 583eac0e80
18 changed files with 304 additions and 309 deletions

View File

@ -32,7 +32,7 @@ class YOLO:
def __init__(self, model='yolov8n.yaml', type="v8") -> None:
"""
> Initializes the YOLO object.
Initializes the YOLO object.
Args:
model (str, Path): model to load or create
@ -59,7 +59,7 @@ class YOLO:
def _new(self, cfg: str, verbose=True):
"""
> Initializes a new model and infers the task type from the model definitions.
Initializes a new model and infers the task type from the model definitions.
Args:
cfg (str): model configuration file
@ -75,7 +75,7 @@ class YOLO:
def _load(self, weights: str):
"""
> Initializes a new model and infers the task type from the model head.
Initializes a new model and infers the task type from the model head.
Args:
weights (str): model checkpoint to be loaded
@ -90,7 +90,7 @@ class YOLO:
def reset(self):
"""
> Resets the model modules.
Resets the model modules.
"""
for m in self.model.modules():
if hasattr(m, 'reset_parameters'):
@ -100,7 +100,7 @@ class YOLO:
def info(self, verbose=False):
"""
> Logs model info.
Logs model info.
Args:
verbose (bool): Controls verbosity.
@ -133,7 +133,7 @@ class YOLO:
@smart_inference_mode()
def val(self, data=None, **kwargs):
"""
> Validate a model on a given dataset .
Validate a model on a given dataset .
Args:
data (str): The dataset to validate on. Accepts all formats accepted by yolo
@ -152,7 +152,7 @@ class YOLO:
@smart_inference_mode()
def export(self, **kwargs):
"""
> Export model.
Export model.
Args:
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
@ -168,7 +168,7 @@ class YOLO:
def train(self, **kwargs):
"""
> Trains the model on a given dataset.
Trains the model on a given dataset.
Args:
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
@ -197,7 +197,7 @@ class YOLO:
def to(self, device):
"""
> Sends the model to the given device.
Sends the model to the given device.
Args:
device (str): device

View File

@ -89,7 +89,7 @@ class BasePredictor:
self.vid_path, self.vid_writer = None, None
self.annotator = None
self.data_path = None
self.output = dict()
self.output = {}
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
callbacks.add_integration_callbacks(self)
@ -216,7 +216,7 @@ class BasePredictor:
self.run_callbacks("on_predict_end")
def predict_cli(self, source=None, model=None, return_outputs=False):
# as __call__ is a genertor now so have to treat it like a genertor
# as __call__ is a generator now so have to treat it like a generator
for _ in (self.__call__(source, model, return_outputs)):
pass

View File

@ -40,7 +40,7 @@ class BaseTrainer:
"""
BaseTrainer
> A base class for creating trainers.
A base class for creating trainers.
Attributes:
args (OmegaConf): Configuration for the trainer.
@ -75,7 +75,7 @@ class BaseTrainer:
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
"""
> Initializes the BaseTrainer class.
Initializes the BaseTrainer class.
Args:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
@ -149,13 +149,13 @@ class BaseTrainer:
def add_callback(self, event: str, callback):
"""
> Appends the given callback.
Appends the given callback.
"""
self.callbacks[event].append(callback)
def set_callback(self, event: str, callback):
"""
> Overrides the existing callbacks with the given callback.
Overrides the existing callbacks with the given callback.
"""
self.callbacks[event] = [callback]
@ -194,7 +194,7 @@ class BaseTrainer:
def _setup_train(self, rank, world_size):
"""
> Builds dataloaders and optimizer on correct rank process.
Builds dataloaders and optimizer on correct rank process.
"""
# model
self.run_callbacks("on_pretrain_routine_start")
@ -383,13 +383,13 @@ class BaseTrainer:
def get_dataset(self, data):
"""
> Get train, val path from data dict if it exists. Returns None if data format is not recognized.
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
"""
return data["train"], data.get("val") or data.get("test")
def setup_model(self):
"""
> load/create/download model for any task.
load/create/download model for any task.
"""
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
@ -415,13 +415,13 @@ class BaseTrainer:
def preprocess_batch(self, batch):
"""
> Allows custom preprocessing model inputs and ground truths depending on task type.
Allows custom preprocessing model inputs and ground truths depending on task type.
"""
return batch
def validate(self):
"""
> Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
"""
metrics = self.validator(self)
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
@ -431,7 +431,7 @@ class BaseTrainer:
def log(self, text, rank=-1):
"""
> Logs the given text to given ranks process if provided, otherwise logs to all ranks.
Logs the given text to given ranks process if provided, otherwise logs to all ranks.
Args"
text (str): text to log
@ -449,13 +449,13 @@ class BaseTrainer:
def get_dataloader(self, dataset_path, batch_size=16, rank=0):
"""
> Returns dataloader derived from torch.data.Dataloader.
Returns dataloader derived from torch.data.Dataloader.
"""
raise NotImplementedError("get_dataloader function not implemented in trainer")
def criterion(self, preds, batch):
"""
> Returns loss and individual loss items as Tensor.
Returns loss and individual loss items as Tensor.
"""
raise NotImplementedError("criterion function not implemented in trainer")
@ -543,7 +543,7 @@ class BaseTrainer:
@staticmethod
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
"""
> Builds an optimizer with the specified parameters and parameter groups.
Builds an optimizer with the specified parameters and parameter groups.
Args:
model (nn.Module): model to optimize