General refactoring and improvements (#373)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -32,7 +32,7 @@ class YOLO:
|
||||
|
||||
def __init__(self, model='yolov8n.yaml', type="v8") -> None:
|
||||
"""
|
||||
> Initializes the YOLO object.
|
||||
Initializes the YOLO object.
|
||||
|
||||
Args:
|
||||
model (str, Path): model to load or create
|
||||
@ -59,7 +59,7 @@ class YOLO:
|
||||
|
||||
def _new(self, cfg: str, verbose=True):
|
||||
"""
|
||||
> Initializes a new model and infers the task type from the model definitions.
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
|
||||
Args:
|
||||
cfg (str): model configuration file
|
||||
@ -75,7 +75,7 @@ class YOLO:
|
||||
|
||||
def _load(self, weights: str):
|
||||
"""
|
||||
> Initializes a new model and infers the task type from the model head.
|
||||
Initializes a new model and infers the task type from the model head.
|
||||
|
||||
Args:
|
||||
weights (str): model checkpoint to be loaded
|
||||
@ -90,7 +90,7 @@ class YOLO:
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
> Resets the model modules.
|
||||
Resets the model modules.
|
||||
"""
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, 'reset_parameters'):
|
||||
@ -100,7 +100,7 @@ class YOLO:
|
||||
|
||||
def info(self, verbose=False):
|
||||
"""
|
||||
> Logs model info.
|
||||
Logs model info.
|
||||
|
||||
Args:
|
||||
verbose (bool): Controls verbosity.
|
||||
@ -133,7 +133,7 @@ class YOLO:
|
||||
@smart_inference_mode()
|
||||
def val(self, data=None, **kwargs):
|
||||
"""
|
||||
> Validate a model on a given dataset .
|
||||
Validate a model on a given dataset .
|
||||
|
||||
Args:
|
||||
data (str): The dataset to validate on. Accepts all formats accepted by yolo
|
||||
@ -152,7 +152,7 @@ class YOLO:
|
||||
@smart_inference_mode()
|
||||
def export(self, **kwargs):
|
||||
"""
|
||||
> Export model.
|
||||
Export model.
|
||||
|
||||
Args:
|
||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
||||
@ -168,7 +168,7 @@ class YOLO:
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
> Trains the model on a given dataset.
|
||||
Trains the model on a given dataset.
|
||||
|
||||
Args:
|
||||
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
|
||||
@ -197,7 +197,7 @@ class YOLO:
|
||||
|
||||
def to(self, device):
|
||||
"""
|
||||
> Sends the model to the given device.
|
||||
Sends the model to the given device.
|
||||
|
||||
Args:
|
||||
device (str): device
|
||||
|
@ -89,7 +89,7 @@ class BasePredictor:
|
||||
self.vid_path, self.vid_writer = None, None
|
||||
self.annotator = None
|
||||
self.data_path = None
|
||||
self.output = dict()
|
||||
self.output = {}
|
||||
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
@ -216,7 +216,7 @@ class BasePredictor:
|
||||
self.run_callbacks("on_predict_end")
|
||||
|
||||
def predict_cli(self, source=None, model=None, return_outputs=False):
|
||||
# as __call__ is a genertor now so have to treat it like a genertor
|
||||
# as __call__ is a generator now so have to treat it like a generator
|
||||
for _ in (self.__call__(source, model, return_outputs)):
|
||||
pass
|
||||
|
||||
|
@ -40,7 +40,7 @@ class BaseTrainer:
|
||||
"""
|
||||
BaseTrainer
|
||||
|
||||
> A base class for creating trainers.
|
||||
A base class for creating trainers.
|
||||
|
||||
Attributes:
|
||||
args (OmegaConf): Configuration for the trainer.
|
||||
@ -75,7 +75,7 @@ class BaseTrainer:
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
"""
|
||||
> Initializes the BaseTrainer class.
|
||||
Initializes the BaseTrainer class.
|
||||
|
||||
Args:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
@ -149,13 +149,13 @@ class BaseTrainer:
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
> Appends the given callback.
|
||||
Appends the given callback.
|
||||
"""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, event: str, callback):
|
||||
"""
|
||||
> Overrides the existing callbacks with the given callback.
|
||||
Overrides the existing callbacks with the given callback.
|
||||
"""
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
@ -194,7 +194,7 @@ class BaseTrainer:
|
||||
|
||||
def _setup_train(self, rank, world_size):
|
||||
"""
|
||||
> Builds dataloaders and optimizer on correct rank process.
|
||||
Builds dataloaders and optimizer on correct rank process.
|
||||
"""
|
||||
# model
|
||||
self.run_callbacks("on_pretrain_routine_start")
|
||||
@ -383,13 +383,13 @@ class BaseTrainer:
|
||||
|
||||
def get_dataset(self, data):
|
||||
"""
|
||||
> Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
||||
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
||||
"""
|
||||
return data["train"], data.get("val") or data.get("test")
|
||||
|
||||
def setup_model(self):
|
||||
"""
|
||||
> load/create/download model for any task.
|
||||
load/create/download model for any task.
|
||||
"""
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
@ -415,13 +415,13 @@ class BaseTrainer:
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""
|
||||
> Allows custom preprocessing model inputs and ground truths depending on task type.
|
||||
Allows custom preprocessing model inputs and ground truths depending on task type.
|
||||
"""
|
||||
return batch
|
||||
|
||||
def validate(self):
|
||||
"""
|
||||
> Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
||||
Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
||||
"""
|
||||
metrics = self.validator(self)
|
||||
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||
@ -431,7 +431,7 @@ class BaseTrainer:
|
||||
|
||||
def log(self, text, rank=-1):
|
||||
"""
|
||||
> Logs the given text to given ranks process if provided, otherwise logs to all ranks.
|
||||
Logs the given text to given ranks process if provided, otherwise logs to all ranks.
|
||||
|
||||
Args"
|
||||
text (str): text to log
|
||||
@ -449,13 +449,13 @@ class BaseTrainer:
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0):
|
||||
"""
|
||||
> Returns dataloader derived from torch.data.Dataloader.
|
||||
Returns dataloader derived from torch.data.Dataloader.
|
||||
"""
|
||||
raise NotImplementedError("get_dataloader function not implemented in trainer")
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
"""
|
||||
> Returns loss and individual loss items as Tensor.
|
||||
Returns loss and individual loss items as Tensor.
|
||||
"""
|
||||
raise NotImplementedError("criterion function not implemented in trainer")
|
||||
|
||||
@ -543,7 +543,7 @@ class BaseTrainer:
|
||||
@staticmethod
|
||||
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
||||
"""
|
||||
> Builds an optimizer with the specified parameters and parameter groups.
|
||||
Builds an optimizer with the specified parameters and parameter groups.
|
||||
|
||||
Args:
|
||||
model (nn.Module): model to optimize
|
||||
|
@ -10,7 +10,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
experiment = comet_ml.Experiment(project_name=trainer.args.project or "YOLOv8",)
|
||||
experiment = comet_ml.Experiment(project_name=trainer.args.project or "YOLOv8")
|
||||
experiment.log_parameters(dict(trainer.args))
|
||||
|
||||
|
||||
|
@ -12,7 +12,7 @@ from zipfile import ZipFile
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER
|
||||
from ultralytics.yolo.utils import LOGGER, SETTINGS
|
||||
|
||||
|
||||
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
||||
@ -59,7 +59,11 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
|
||||
|
||||
file = Path(str(file).strip().replace("'", ''))
|
||||
if not file.exists():
|
||||
if file.exists():
|
||||
return str(file)
|
||||
elif (SETTINGS['weights_dir'] / file).exists():
|
||||
return str(SETTINGS['weights_dir'] / file)
|
||||
else:
|
||||
# URL specified
|
||||
name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc.
|
||||
if str(file).startswith(('http:/', 'https:/')): # download
|
||||
@ -94,7 +98,7 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
min_bytes=1E5,
|
||||
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}')
|
||||
|
||||
return str(file)
|
||||
return str(file)
|
||||
|
||||
|
||||
def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1, retry=3):
|
||||
|
@ -58,10 +58,9 @@ class ClassificationPredictor(BasePredictor):
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def predict(cfg):
|
||||
cfg.model = cfg.model or "squeezenet1_0"
|
||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||
cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2) # check image size
|
||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
|
||||
|
||||
predictor = ClassificationPredictor(cfg)
|
||||
predictor.predict_cli()
|
||||
|
||||
|
@ -136,7 +136,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
cfg.model = cfg.model or "yolov8n-cls.yaml" # or "resnet18"
|
||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||
cfg.data = cfg.data or "mnist160" # or yolo.ClassificationDataset("mnist")
|
||||
cfg.lr0 = 0.1
|
||||
cfg.weight_decay = 5e-5
|
||||
@ -151,10 +151,4 @@ def train(cfg):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
yolo task=classify mode=train model=yolov8n-cls.pt data=mnist160 epochs=10 imgsz=32
|
||||
yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160 imgsz=32
|
||||
yolo task=classify mode=predict model=runs/classify/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
|
||||
yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript
|
||||
"""
|
||||
train()
|
||||
|
@ -48,8 +48,8 @@ class ClassificationValidator(BaseValidator):
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def val(cfg):
|
||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||
cfg.data = cfg.data or "imagenette160"
|
||||
cfg.model = cfg.model or "resnet18"
|
||||
validator = ClassificationValidator(args=cfg)
|
||||
validator(model=cfg.model)
|
||||
|
||||
|
@ -197,7 +197,7 @@ class Loss:
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
cfg.model = cfg.model or "yolov8n.yaml"
|
||||
cfg.model = cfg.model or "yolov8n.pt"
|
||||
cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist")
|
||||
cfg.device = cfg.device if cfg.device is not None else ''
|
||||
# trainer = DetectionTrainer(cfg)
|
||||
@ -208,11 +208,4 @@ def train(cfg):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
CLI usage:
|
||||
python ultralytics/yolo/v8/detect/train.py model=yolov8n.yaml data=coco128 epochs=100 imgsz=640
|
||||
|
||||
TODO:
|
||||
yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100
|
||||
"""
|
||||
train()
|
||||
|
@ -234,6 +234,7 @@ class DetectionValidator(BaseValidator):
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def val(cfg):
|
||||
cfg.model = cfg.model or "yolov8n.pt"
|
||||
cfg.data = cfg.data or "coco128.yaml"
|
||||
validator = DetectionValidator(args=cfg)
|
||||
validator(model=cfg.model)
|
||||
|
@ -143,7 +143,7 @@ class SegLoss(Loss):
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
cfg.model = cfg.model or "yolov8n-seg.yaml"
|
||||
cfg.model = cfg.model or "yolov8n-seg.pt"
|
||||
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
|
||||
cfg.device = cfg.device if cfg.device is not None else ''
|
||||
# trainer = SegmentationTrainer(cfg)
|
||||
@ -154,11 +154,4 @@ def train(cfg):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
CLI usage:
|
||||
python ultralytics/yolo/v8/segment/train.py model=yolov8n-seg.yaml data=coco128-segments epochs=100 imgsz=640
|
||||
|
||||
TODO:
|
||||
Direct cli support, i.e, yolov8 classify_train args.epochs 10
|
||||
"""
|
||||
train()
|
||||
|
@ -114,8 +114,9 @@ class SegmentationValidator(DetectionValidator):
|
||||
masks=True)
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(predn, labelsn)
|
||||
self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:,
|
||||
5], cls.squeeze(-1))) # conf, pcls, tcls
|
||||
|
||||
# Append correct_masks, correct_boxes, pconf, pcls, tcls
|
||||
self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
|
||||
|
||||
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
|
||||
if self.args.plots and self.batch_i < 3:
|
||||
|
Reference in New Issue
Block a user