Docstring additions (#122)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2022-12-31 13:42:45 +01:00
committed by GitHub
parent c9f3e469cb
commit df4fc14c10
10 changed files with 291 additions and 73 deletions

View File

@ -116,14 +116,31 @@ def try_export(inner_func):
class Exporter:
"""
Exporter
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
A class for exporting a model.
Attributes:
args (OmegaConf): Configuration for the exporter.
save_dir (Path): Directory to save results.
"""
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
"""
Initializes the Exporter class.
Args:
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
if overrides is None:
overrides = {}
self.args = get_config(config, overrides)
project = self.args.project or f"runs/{self.args.task}"
name = self.args.name or "exp" # hardcode mode as export doesn't require it
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
self.save_dir.mkdir(parents=True, exist_ok=True)
self.imgsz = self.args.imgsz
@smart_inference_mode()
def __call__(self, model=None):
@ -143,7 +160,7 @@ class Exporter:
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
# Checks
self.imgsz = check_imgsz(self.imgsz, stride=model.stride, min_dim=2) # check image size
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
if self.args.optimize:
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'

View File

@ -1,6 +1,6 @@
import torch
from ultralytics import yolo # noqa required for python usage
from ultralytics import yolo # noqa
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.engine.exporter import Exporter
@ -9,7 +9,7 @@ from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
from ultralytics.yolo.utils.files import yaml_load
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
# map head: [model, trainer, validator, predictor]
# Map head to model, trainer, validator, and predictor classes
MODEL_MAP = {
"classify": [
ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
@ -24,39 +24,44 @@ MODEL_MAP = {
class YOLO:
"""
Python interface which emulates a model-like behaviour by wrapping trainers.
YOLO
A python interface which emulates a model-like behaviour by wrapping trainers.
"""
__init_key = object()
__init_key = object() # used to ensure proper initialization
def __init__(self, init_key=None, type="v8") -> None:
"""
Initializes the YOLO object.
Args:
type (str): Type/version of models to use
init_key (object): used to ensure proper initialization. Defaults to None.
type (str): Type/version of models to use. Defaults to "v8".
"""
if init_key != YOLO.__init_key:
raise SyntaxError(HELP_MSG)
self.type = type
self.ModelClass = None
self.TrainerClass = None
self.ValidatorClass = None
self.PredictorClass = None
self.model = None
self.trainer = None
self.task = None
self.ModelClass = None # model class
self.TrainerClass = None # trainer class
self.ValidatorClass = None # validator class
self.PredictorClass = None # predictor class
self.model = None # model object
self.trainer = None # trainer object
self.task = None # task type
self.ckpt = None # if loaded from *.pt
self.cfg = None # if loaded from *.yaml
self.overrides = {}
self.init_disabled = False
self.overrides = {} # overrides for trainer object
self.init_disabled = False # disable model initialization
@classmethod
def new(cls, cfg: str, verbose=True):
"""
Initializes a new model and infers the task type from the model definitions
Initializes a new model and infers the task type from the model definitions.
Args:
cfg (str): model configuration file
verbsoe (bool): display model info on load
verbose (bool): display model info on load
"""
cfg = check_yaml(cfg) # check YAML
cfg_dict = yaml_load(cfg) # model dict

View File

@ -41,8 +41,36 @@ from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mo
class BasePredictor:
"""
BasePredictor
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
A base class for creating predictors.
Attributes:
args (OmegaConf): Configuration for the predictor.
save_dir (Path): Directory to save results.
done_setup (bool): Whether the predictor has finished setup.
model (nn.Module): Model used for prediction.
data (dict): Data configuration.
device (torch.device): Device used for prediction.
dataset (Dataset): Dataset used for prediction.
vid_path (str): Path to video file.
vid_writer (cv2.VideoWriter): Video writer for saving video output.
view_img (bool): Whether to view image output.
annotator (Annotator): Annotator used for prediction.
data_path (str): Path to data.
"""
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
"""
Initializes the BasePredictor class.
Args:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
if overrides is None:
overrides = {}
self.args = get_config(config, overrides)
project = self.args.project or f"runs/{self.args.task}"
name = self.args.name or f"{self.args.mode}"

View File

@ -33,9 +33,53 @@ from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds
class BaseTrainer:
"""
BaseTrainer
def __init__(self, cfg=DEFAULT_CONFIG, overrides={}):
self.args = get_config(cfg, overrides)
A base class for creating trainers.
Attributes:
args (OmegaConf): Configuration for the trainer.
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
console (logging.Logger): Logger instance.
validator (BaseValidator): Validator instance.
model (nn.Module): Model instance.
callbacks (defaultdict): Dictionary of callbacks.
save_dir (Path): Directory to save results.
wdir (Path): Directory to save weights.
last (Path): Path to last checkpoint.
best (Path): Path to best checkpoint.
batch_size (int): Batch size for training.
epochs (int): Number of epochs to train for.
start_epoch (int): Starting epoch for training.
device (torch.device): Device to use for training.
amp (bool): Flag to enable AMP (Automatic Mixed Precision).
scaler (amp.GradScaler): Gradient scaler for AMP.
data (str): Path to data.
trainset (torch.utils.data.Dataset): Training dataset.
testset (torch.utils.data.Dataset): Testing dataset.
ema (nn.Module): EMA (Exponential Moving Average) of the model.
lf (nn.Module): Loss function.
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
best_fitness (float): The best fitness value achieved.
fitness (float): Current fitness value.
loss (float): Current loss value.
tloss (float): Total loss value.
loss_names (list): List of loss names.
csv (Path): Path to results CSV file.
"""
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
"""
Initializes the BaseTrainer class.
Args:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
if overrides is None:
overrides = {}
self.args = get_config(config, overrides)
self.check_resume()
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
@ -464,6 +508,19 @@ class BaseTrainer:
@staticmethod
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
"""
Builds an optimizer with the specified parameters and parameter groups.
Args:
model (nn.Module): model to optimize
name (str): name of the optimizer to use
lr (float): learning rate
momentum (float): momentum
decay (float): weight decay
Returns:
torch.optim.Optimizer: the built optimizer
"""
g = [], [], [] # optimizer parameter groups
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
for v in model.modules():

View File

@ -16,10 +16,36 @@ from ultralytics.yolo.utils.torch_utils import de_parallel, select_device, smart
class BaseValidator:
"""
Base validator class.
BaseValidator
A base class for creating validators.
Attributes:
dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation.
logger (logging.Logger): Logger to use for validation.
args (OmegaConf): Configuration for the validator.
model (nn.Module): Model to validate.
data (dict): Data dictionary.
device (torch.device): Device to use for validation.
batch_i (int): Current batch index.
training (bool): Whether the model is in training mode.
speed (float): Batch processing speed in seconds.
jdict (dict): Dictionary to store validation results.
save_dir (Path): Directory to save results.
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
"""
Initializes a BaseValidator instance.
Args:
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
save_dir (Path): Directory to save results.
pbar (tqdm.tqdm): Progress bar for displaying progress.
logger (logging.Logger): Logger to log messages.
args (OmegaConf): Configuration for the validator.
"""
self.dataloader = dataloader
self.pbar = pbar
self.logger = logger or LOGGER