Docstring additions (#122)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -116,14 +116,31 @@ def try_export(inner_func):
|
||||
|
||||
|
||||
class Exporter:
|
||||
"""
|
||||
Exporter
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
|
||||
A class for exporting a model.
|
||||
|
||||
Attributes:
|
||||
args (OmegaConf): Configuration for the exporter.
|
||||
save_dir (Path): Directory to save results.
|
||||
"""
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
"""
|
||||
Initializes the Exporter class.
|
||||
|
||||
Args:
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
self.args = get_config(config, overrides)
|
||||
project = self.args.project or f"runs/{self.args.task}"
|
||||
name = self.args.name or "exp" # hardcode mode as export doesn't require it
|
||||
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.imgsz = self.args.imgsz
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, model=None):
|
||||
@ -143,7 +160,7 @@ class Exporter:
|
||||
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
|
||||
|
||||
# Checks
|
||||
self.imgsz = check_imgsz(self.imgsz, stride=model.stride, min_dim=2) # check image size
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
||||
if self.args.optimize:
|
||||
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
|
||||
from ultralytics import yolo # noqa required for python usage
|
||||
from ultralytics import yolo # noqa
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
@ -9,7 +9,7 @@ from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
|
||||
from ultralytics.yolo.utils.files import yaml_load
|
||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
|
||||
|
||||
# map head: [model, trainer, validator, predictor]
|
||||
# Map head to model, trainer, validator, and predictor classes
|
||||
MODEL_MAP = {
|
||||
"classify": [
|
||||
ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
|
||||
@ -24,39 +24,44 @@ MODEL_MAP = {
|
||||
|
||||
class YOLO:
|
||||
"""
|
||||
Python interface which emulates a model-like behaviour by wrapping trainers.
|
||||
YOLO
|
||||
|
||||
A python interface which emulates a model-like behaviour by wrapping trainers.
|
||||
"""
|
||||
__init_key = object()
|
||||
__init_key = object() # used to ensure proper initialization
|
||||
|
||||
def __init__(self, init_key=None, type="v8") -> None:
|
||||
"""
|
||||
Initializes the YOLO object.
|
||||
|
||||
Args:
|
||||
type (str): Type/version of models to use
|
||||
init_key (object): used to ensure proper initialization. Defaults to None.
|
||||
type (str): Type/version of models to use. Defaults to "v8".
|
||||
"""
|
||||
if init_key != YOLO.__init_key:
|
||||
raise SyntaxError(HELP_MSG)
|
||||
|
||||
self.type = type
|
||||
self.ModelClass = None
|
||||
self.TrainerClass = None
|
||||
self.ValidatorClass = None
|
||||
self.PredictorClass = None
|
||||
self.model = None
|
||||
self.trainer = None
|
||||
self.task = None
|
||||
self.ModelClass = None # model class
|
||||
self.TrainerClass = None # trainer class
|
||||
self.ValidatorClass = None # validator class
|
||||
self.PredictorClass = None # predictor class
|
||||
self.model = None # model object
|
||||
self.trainer = None # trainer object
|
||||
self.task = None # task type
|
||||
self.ckpt = None # if loaded from *.pt
|
||||
self.cfg = None # if loaded from *.yaml
|
||||
self.overrides = {}
|
||||
self.init_disabled = False
|
||||
self.overrides = {} # overrides for trainer object
|
||||
self.init_disabled = False # disable model initialization
|
||||
|
||||
@classmethod
|
||||
def new(cls, cfg: str, verbose=True):
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model definitions
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
|
||||
Args:
|
||||
cfg (str): model configuration file
|
||||
verbsoe (bool): display model info on load
|
||||
verbose (bool): display model info on load
|
||||
"""
|
||||
cfg = check_yaml(cfg) # check YAML
|
||||
cfg_dict = yaml_load(cfg) # model dict
|
||||
|
@ -41,8 +41,36 @@ from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mo
|
||||
|
||||
|
||||
class BasePredictor:
|
||||
"""
|
||||
BasePredictor
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
|
||||
A base class for creating predictors.
|
||||
|
||||
Attributes:
|
||||
args (OmegaConf): Configuration for the predictor.
|
||||
save_dir (Path): Directory to save results.
|
||||
done_setup (bool): Whether the predictor has finished setup.
|
||||
model (nn.Module): Model used for prediction.
|
||||
data (dict): Data configuration.
|
||||
device (torch.device): Device used for prediction.
|
||||
dataset (Dataset): Dataset used for prediction.
|
||||
vid_path (str): Path to video file.
|
||||
vid_writer (cv2.VideoWriter): Video writer for saving video output.
|
||||
view_img (bool): Whether to view image output.
|
||||
annotator (Annotator): Annotator used for prediction.
|
||||
data_path (str): Path to data.
|
||||
"""
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
"""
|
||||
Initializes the BasePredictor class.
|
||||
|
||||
Args:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
self.args = get_config(config, overrides)
|
||||
project = self.args.project or f"runs/{self.args.task}"
|
||||
name = self.args.name or f"{self.args.mode}"
|
||||
|
@ -33,9 +33,53 @@ from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds
|
||||
|
||||
|
||||
class BaseTrainer:
|
||||
"""
|
||||
BaseTrainer
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CONFIG, overrides={}):
|
||||
self.args = get_config(cfg, overrides)
|
||||
A base class for creating trainers.
|
||||
|
||||
Attributes:
|
||||
args (OmegaConf): Configuration for the trainer.
|
||||
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
|
||||
console (logging.Logger): Logger instance.
|
||||
validator (BaseValidator): Validator instance.
|
||||
model (nn.Module): Model instance.
|
||||
callbacks (defaultdict): Dictionary of callbacks.
|
||||
save_dir (Path): Directory to save results.
|
||||
wdir (Path): Directory to save weights.
|
||||
last (Path): Path to last checkpoint.
|
||||
best (Path): Path to best checkpoint.
|
||||
batch_size (int): Batch size for training.
|
||||
epochs (int): Number of epochs to train for.
|
||||
start_epoch (int): Starting epoch for training.
|
||||
device (torch.device): Device to use for training.
|
||||
amp (bool): Flag to enable AMP (Automatic Mixed Precision).
|
||||
scaler (amp.GradScaler): Gradient scaler for AMP.
|
||||
data (str): Path to data.
|
||||
trainset (torch.utils.data.Dataset): Training dataset.
|
||||
testset (torch.utils.data.Dataset): Testing dataset.
|
||||
ema (nn.Module): EMA (Exponential Moving Average) of the model.
|
||||
lf (nn.Module): Loss function.
|
||||
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
|
||||
best_fitness (float): The best fitness value achieved.
|
||||
fitness (float): Current fitness value.
|
||||
loss (float): Current loss value.
|
||||
tloss (float): Total loss value.
|
||||
loss_names (list): List of loss names.
|
||||
csv (Path): Path to results CSV file.
|
||||
"""
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
"""
|
||||
Initializes the BaseTrainer class.
|
||||
|
||||
Args:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
self.args = get_config(config, overrides)
|
||||
self.check_resume()
|
||||
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
||||
|
||||
@ -464,6 +508,19 @@ class BaseTrainer:
|
||||
|
||||
@staticmethod
|
||||
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
||||
"""
|
||||
Builds an optimizer with the specified parameters and parameter groups.
|
||||
|
||||
Args:
|
||||
model (nn.Module): model to optimize
|
||||
name (str): name of the optimizer to use
|
||||
lr (float): learning rate
|
||||
momentum (float): momentum
|
||||
decay (float): weight decay
|
||||
|
||||
Returns:
|
||||
torch.optim.Optimizer: the built optimizer
|
||||
"""
|
||||
g = [], [], [] # optimizer parameter groups
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
||||
for v in model.modules():
|
||||
|
@ -16,10 +16,36 @@ from ultralytics.yolo.utils.torch_utils import de_parallel, select_device, smart
|
||||
|
||||
class BaseValidator:
|
||||
"""
|
||||
Base validator class.
|
||||
BaseValidator
|
||||
|
||||
A base class for creating validators.
|
||||
|
||||
Attributes:
|
||||
dataloader (DataLoader): Dataloader to use for validation.
|
||||
pbar (tqdm): Progress bar to update during validation.
|
||||
logger (logging.Logger): Logger to use for validation.
|
||||
args (OmegaConf): Configuration for the validator.
|
||||
model (nn.Module): Model to validate.
|
||||
data (dict): Data dictionary.
|
||||
device (torch.device): Device to use for validation.
|
||||
batch_i (int): Current batch index.
|
||||
training (bool): Whether the model is in training mode.
|
||||
speed (float): Batch processing speed in seconds.
|
||||
jdict (dict): Dictionary to store validation results.
|
||||
save_dir (Path): Directory to save results.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
||||
"""
|
||||
Initializes a BaseValidator instance.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
|
||||
save_dir (Path): Directory to save results.
|
||||
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
||||
logger (logging.Logger): Logger to log messages.
|
||||
args (OmegaConf): Configuration for the validator.
|
||||
"""
|
||||
self.dataloader = dataloader
|
||||
self.pbar = pbar
|
||||
self.logger = logger or LOGGER
|
||||
|
Reference in New Issue
Block a user