Add initial model interface (#30)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		
							
								
								
									
										13
									
								
								ultralytics/tests/test_model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								ultralytics/tests/test_model.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,13 @@ | ||||
| from ultralytics.yolo import YOLO | ||||
|  | ||||
|  | ||||
def test_model():
    """
    Smoke-test the top-level YOLO interface: build a model from a YAML config,
    then run a single training epoch with a small learning-rate override.
    """
    yolo_model = YOLO()
    yolo_model.new("assets/dummy_model.yaml")
    # temp solution before get_model is implemented: hand the trainer a model name instead
    yolo_model.model = "squeezenet1_0"
    # yolo_model.load("yolov5n.pt")
    train_kwargs = dict(data="imagenette160", epochs=1, lr0=0.01)
    yolo_model.train(**train_kwargs)
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     test_model() | ||||
| @ -1,4 +1,7 @@ | ||||
| import ultralytics.yolo.v8 as v8 | ||||
|  | ||||
| from .engine.model import YOLO | ||||
| from .engine.trainer import BaseTrainer | ||||
| from .engine.validator import BaseValidator | ||||
|  | ||||
| __all__ = ["BaseTrainer", "BaseValidator"]  # allow simpler import | ||||
| __all__ = ["BaseTrainer", "BaseValidator", "YOLO"]  # allow simpler import | ||||
|  | ||||
| @ -728,7 +728,7 @@ def classify_albumentations( | ||||
|                 if vflip > 0: | ||||
|                     T += [A.VerticalFlip(p=vflip)] | ||||
|                 if jitter > 0: | ||||
|                     color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, satuaration, 0 hue | ||||
|                     color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, saturation, 0 hue | ||||
|                     T += [A.ColorJitter(*color_jitter, 0)] | ||||
|         else:  # Use fixed crop for eval set (reproducibility) | ||||
|             T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)] | ||||
|  | ||||
| @ -51,7 +51,8 @@ def exif_size(img): | ||||
| def verify_image_label(args): | ||||
|     # Verify one image-label pair | ||||
|     im_file, lb_file, prefix, keypoint = args | ||||
|     nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None  # number (missing, found, empty, corrupt), message, segments, keypoints | ||||
|     # number (missing, found, empty, corrupt), message, segments, keypoints | ||||
|     nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None | ||||
|     try: | ||||
|         # verify images | ||||
|         im = Image.open(im_file) | ||||
| @ -86,10 +87,10 @@ def verify_image_label(args): | ||||
|                     kpts = np.zeros((lb.shape[0], 39)) | ||||
|                     for i in range(len(lb)): | ||||
|                         kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5, | ||||
|                                                              3))  # remove the occlusion paramater from the GT | ||||
|                                                              3))  # remove the occlusion parameter from the GT | ||||
|                         kpts[i] = np.hstack((lb[i, :5], kpt)) | ||||
|                     lb = kpts | ||||
|                     assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion paramater" | ||||
|                     assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter" | ||||
|                 else: | ||||
|                     assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected" | ||||
|                     assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}" | ||||
|  | ||||
							
								
								
									
										63
									
								
								ultralytics/yolo/engine/model.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								ultralytics/yolo/engine/model.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,63 @@ | ||||
| """ | ||||
| Top-level YOLO model interface. First principle usage example - https://github.com/ultralytics/ultralytics/issues/13 | ||||
| """ | ||||
| import torch | ||||
| import yaml | ||||
|  | ||||
| import ultralytics.yolo as yolo | ||||
| from ultralytics.yolo.utils import LOGGER | ||||
| from ultralytics.yolo.utils.checks import check_yaml | ||||
| from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel | ||||
|  | ||||
# Lookup table keyed by the head module name read from a model YAML.
# Each value is [model class, dotted trainer path]; the "VERSION" token in the path
# is a placeholder that is substituted before the path is resolved to a class.
MODEL_MAP = dict(
    Classify=[ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],
    Detect=[ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],  # temp
    Segment=[],
)
|  | ||||
|  | ||||
class YOLO:
    """
    Top-level YOLO model interface.

    Holds a model, the trainer class that goes with it, and (optionally) the path of
    the weights that were last loaded. Typical usage::

        model = YOLO()
        model.new("model.yaml")
        model.train(data="dataset-name", epochs=1)
    """

    def __init__(self, version=8) -> None:
        """
        :param version: YOLO major version; used to pick the ``v{version}`` sub-package
            when the trainer class is resolved.
        """
        self.version = version
        self.model = None  # model instance, set by new()
        self.trainer = None  # trainer class (not instance), set by new()
        self.pretrained_weights = None  # weights passed to the last load() call

    def new(self, cfg: str):
        """
        Create a fresh model (and select its trainer) from a model YAML file.

        :param cfg: model YAML file name, validated with check_yaml before use
        """
        cfg = check_yaml(cfg)  # check YAML
        self.model, self.trainer = self._get_model_and_trainer(cfg)

    def load(self, weights, autodownload=True):
        """
        Load checkpoint weights into the current model and remember the weights path.

        When no model has been created yet, only the weights path is recorded
        (model/trainer inference from a checkpoint is still a TODO).

        :param weights: checkpoint file passed to the model's own ``load``
        :param autodownload: currently unused
        """
        if self.pretrained_weights is not None:
            LOGGER.info("Overwriting weights")
        # TODO: weights = smart_file_loader(weights)
        # Compare against None explicitly: truthiness of a model object is not a
        # reliable "is set" signal (container-like modules can be falsy when empty).
        if self.model is not None:
            self.model.load(weights)
            LOGGER.info("Checkpoint loaded successfully")
        else:
            # TODO: infer model and trainer
            pass

        self.pretrained_weights = weights

    def reset(self):
        """Placeholder; currently does nothing."""
        pass

    def train(self, **kwargs):
        """
        Train the current model with the trainer selected by new().

        :param kwargs: trainer config overrides; ``data`` is mandatory. The current
            model is injected under the ``model`` key before the trainer is built.
        :raises ValueError: if ``data`` is not supplied
        :raises RuntimeError: if no model or no trainer has been set up
        """
        if 'data' not in kwargs:
            raise ValueError("data is required to train")
        if self.model is None:
            raise RuntimeError("model not initialized. Use .new() or .load()")
        if self.trainer is None:
            # Without this guard the call below fails with an opaque
            # "'NoneType' object is not callable".
            raise RuntimeError("trainer not initialized. Use .new() to select a trainer")
        kwargs["model"] = self.model
        trainer = self.trainer(overrides=kwargs)
        trainer.train()

    def _get_model_and_trainer(self, cfg):
        """
        Read a model YAML and return ``(model instance, trainer class)``.

        The head type is taken from ``cfg["head"][-1][-2]`` and looked up in MODEL_MAP.

        :param cfg: path of the model YAML file
        :raises KeyError: if the head type has no MODEL_MAP entry
        :raises NotImplementedError: if the MODEL_MAP entry for the head type is empty
        """
        with open(cfg, encoding='ascii', errors='ignore') as f:
            cfg = yaml.safe_load(f)  # model dict
        head = cfg["head"][-1][-2]
        entry = MODEL_MAP[head]
        if not entry:
            raise NotImplementedError(f"'{head}' models are not supported yet")
        model, trainer_path = entry

        # Resolve "yolo.vN.<...>.Trainer" by walking attributes from the imported
        # `yolo` package. This replaces eval(), which would execute arbitrary code
        # if a crafted `version` value were interpolated into the path.
        _, *attr_names = trainer_path.replace("VERSION", f"v{self.version}").split(".")
        trainer = yolo
        for attr_name in attr_names:
            trainer = getattr(trainer, attr_name)

        return model(cfg), trainer
| @ -7,7 +7,7 @@ import time | ||||
| from collections import defaultdict | ||||
| from datetime import datetime | ||||
| from pathlib import Path | ||||
| from typing import Union | ||||
| from typing import Dict, Union | ||||
|  | ||||
| import torch | ||||
| import torch.distributed as dist | ||||
| @ -29,30 +29,29 @@ DEFAULT_CONFIG = "defaults.yaml" | ||||
|  | ||||
| class BaseTrainer: | ||||
|  | ||||
|     def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG): | ||||
|     def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG, overrides={}): | ||||
|         self.console = LOGGER | ||||
|         self.model, self.data, self.train, self.hyps = self._get_config(config) | ||||
|         self.args = self._get_config(config, overrides) | ||||
|         self.validator = None | ||||
|         self.callbacks = defaultdict(list) | ||||
|         self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}")  # to debug | ||||
|         self.console.info(f"Training config: \n args: \n {self.args}")  # to debug | ||||
|         # Directories | ||||
|         self.save_dir = increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok) | ||||
|         self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok) | ||||
|         self.wdir = self.save_dir / 'weights' | ||||
|         self.wdir.mkdir(parents=True, exist_ok=True)  # make dir | ||||
|         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' | ||||
|  | ||||
|         # Save run settings | ||||
|         save_yaml(self.save_dir / 'train.yaml', OmegaConf.to_container(self.train, resolve=True)) | ||||
|         save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) | ||||
|  | ||||
|         # device | ||||
|         self.device = utils.torch_utils.select_device(self.train.device, self.train.batch_size) | ||||
|         self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size) | ||||
|         self.console.info(f"running on device {self.device}") | ||||
|         self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu') | ||||
|  | ||||
|         # Model and Dataloaders. | ||||
|         self.trainset, self.testset = self.get_dataset()  # initialize dataset before as nc is needed for model | ||||
|         self.model = self.get_model() | ||||
|         self.model = self.model.to(self.device) | ||||
|         self.trainset, self.testset = self.get_dataset(self.args.data) | ||||
|         self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device) | ||||
|  | ||||
|         # epoch level metrics | ||||
|         self.metrics = {}  # handle metrics returned by validator | ||||
| @ -63,18 +62,24 @@ class BaseTrainer: | ||||
|         for callback, func in loggers.default_callbacks.items(): | ||||
|             self.add_callback(callback, func) | ||||
|  | ||||
|     def _get_config(self, config: Union[str, Path, DictConfig] = None): | ||||
|     def _get_config(self, config: Union[str, DictConfig], overrides: Union[str, Dict] = {}): | ||||
|         """ | ||||
|         Accepts yaml file name or DictConfig containing experiment configuration. | ||||
|         Returns train and hyps namespace | ||||
|         Returns training args namespace | ||||
|         :param config: Optional file name or DictConfig object | ||||
|         """ | ||||
|         try: | ||||
|             if isinstance(config, (str, Path)): | ||||
|                 config = OmegaConf.load(config) | ||||
|             return config.model, config.data, config.train, config.hyps | ||||
|         except KeyError as e: | ||||
|             raise KeyError("Missing key(s) in config") from e | ||||
|         if isinstance(config, (str, Path)): | ||||
|             config = OmegaConf.load(config) | ||||
|         elif isinstance(config, Dict): | ||||
|             config = OmegaConf.create(config) | ||||
|  | ||||
|         # override | ||||
|         if isinstance(overrides, str): | ||||
|             overrides = OmegaConf.load(overrides) | ||||
|         elif isinstance(overrides, Dict): | ||||
|             overrides = OmegaConf.create(overrides) | ||||
|  | ||||
|         return OmegaConf.merge(config, overrides) | ||||
|  | ||||
|     def add_callback(self, onevent: str, callback): | ||||
|         """ | ||||
| @ -92,7 +97,7 @@ class BaseTrainer: | ||||
|         for callback in self.callbacks.get(onevent, []): | ||||
|             callback(self) | ||||
|  | ||||
|     def run(self): | ||||
|     def train(self): | ||||
|         world_size = torch.cuda.device_count() | ||||
|         if world_size > 1: | ||||
|             mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True) | ||||
| @ -109,21 +114,21 @@ class BaseTrainer: | ||||
|         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size) | ||||
|         self.model = self.model.to(self.device) | ||||
|         self.model = DDP(self.model, device_ids=[rank]) | ||||
|         self.train.batch_size = self.train.batch_size // world_size | ||||
|         self.args.batch_size = self.args.batch_size // world_size | ||||
|  | ||||
|     def _setup_train(self, rank): | ||||
|         """ | ||||
|         Builds dataloaders and optimizer on correct rank process | ||||
|         """ | ||||
|         self.optimizer = build_optimizer(model=self.model, | ||||
|                                          name=self.train.optimizer, | ||||
|                                          lr=self.hyps.lr0, | ||||
|                                          momentum=self.hyps.momentum, | ||||
|                                          decay=self.hyps.weight_decay) | ||||
|         self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank) | ||||
|                                          name=self.args.optimizer, | ||||
|                                          lr=self.args.lr0, | ||||
|                                          momentum=self.args.momentum, | ||||
|                                          decay=self.args.weight_decay) | ||||
|         self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank) | ||||
|         if rank in {0, -1}: | ||||
|             print(" Creating testloader rank :", rank) | ||||
|             self.test_loader = self.get_dataloader(self.testset, batch_size=self.train.batch_size * 2, rank=rank) | ||||
|             self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank) | ||||
|             self.validator = self.get_validator() | ||||
|             print("created testloader :", rank) | ||||
|  | ||||
| @ -138,7 +143,7 @@ class BaseTrainer: | ||||
|         self.epoch_time = None | ||||
|         self.epoch_time_start = time.time() | ||||
|         self.train_time_start = time.time() | ||||
|         for epoch in range(self.train.epochs): | ||||
|         for epoch in range(self.args.epochs): | ||||
|             # callback hook. on_epoch_start | ||||
|             self.model.train() | ||||
|             pbar = enumerate(self.train_loader) | ||||
| @ -165,7 +170,7 @@ class BaseTrainer: | ||||
|                 # log | ||||
|                 mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB) | ||||
|                 if rank in {-1, 0}: | ||||
|                     pbar.desc = f"{f'{epoch + 1}/{self.train.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36 | ||||
|                     pbar.desc = f"{f'{epoch + 1}/{self.args.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36 | ||||
|  | ||||
|             if rank in [-1, 0]: | ||||
|                 # validation | ||||
| @ -174,7 +179,7 @@ class BaseTrainer: | ||||
|                 # callback: on_val_end() | ||||
|  | ||||
|                 # save model | ||||
|                 if (not self.train.nosave) or (self.epoch + 1 == self.train.epochs): | ||||
|                 if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs): | ||||
|                     self.save_model() | ||||
|                     # callback; on_model_save | ||||
|  | ||||
| @ -198,7 +203,7 @@ class BaseTrainer: | ||||
|             'ema': None,  # deepcopy(ema.ema).half(), | ||||
|             'updates': None,  # ema.updates, | ||||
|             'optimizer': None,  # optimizer.state_dict(), | ||||
|             'train_args': self.train, | ||||
|             'train_args': self.args, | ||||
|             'date': datetime.now().isoformat()} | ||||
|  | ||||
|         # Save last, best and delete | ||||
| @ -207,22 +212,22 @@ class BaseTrainer: | ||||
|             torch.save(ckpt, self.best) | ||||
|         del ckpt | ||||
|  | ||||
|     def get_dataloader(self, path): | ||||
|     def get_dataloader(self, dataset_path, batch_size=16, rank=0): | ||||
|         """ | ||||
|         Returns dataloader derived from torch.data.Dataloader | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def get_dataset(self): | ||||
|     def get_dataset(self, data): | ||||
|         """ | ||||
|         Uses self.dataset to download the dataset if needed and verify it. | ||||
|         Download the dataset if needed and verify it. | ||||
|         Returns train and val split datasets | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def get_model(self): | ||||
|     def get_model(self, model, pretrained=True): | ||||
|         """ | ||||
|         Uses self.model to load/create/download dataset for any task | ||||
|         load/create/download model for any task | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
| @ -238,7 +243,7 @@ class BaseTrainer: | ||||
|  | ||||
|     def preprocess_batch(self, images, labels): | ||||
|         """ | ||||
|         Allows custom preprocessing model inputs and ground truths depeding on task type | ||||
|         Allows custom preprocessing model inputs and ground truths depending on task type | ||||
|         """ | ||||
|         return images.to(self.device, non_blocking=True), labels.to(self.device) | ||||
|  | ||||
|  | ||||
| @ -1,53 +1,56 @@ | ||||
| model: null | ||||
| data: null | ||||
| train: | ||||
|   epochs: 300 | ||||
|   batch_size: 16 | ||||
|   img_size: 640 | ||||
|   nosave: False | ||||
|   cache: False # True/ram for ram, or disc | ||||
|   device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu | ||||
|   workers: 8 | ||||
|   project: "ultralytics-yolo" | ||||
|   name: "exp" # TODO: make this informative, maybe exp{#number}_{datetime} ? | ||||
|   exist_ok: False | ||||
|   pretrained: False | ||||
|   optimizer: "Adam" # choices=['SGD', 'Adam', 'AdamW', 'RMSProp'] | ||||
|   verbose: False | ||||
|   seed: 0 | ||||
|   local_rank: -1 | ||||
|  | ||||
| hyps: | ||||
|   lr0: 0.001  # initial learning rate (SGD=1E-2, Adam=1E-3) | ||||
|   lrf: 0.01  # final OneCycleLR learning rate (lr0 * lrf) | ||||
|   momentum: 0.937  # SGD momentum/Adam beta1 | ||||
|   weight_decay: 0.0005  # optimizer weight decay 5e-4 | ||||
|   warmup_epochs: 3.0  # warmup epochs (fractions ok) | ||||
|   warmup_momentum: 0.8  # warmup initial momentum | ||||
|   warmup_bias_lr: 0.1  # warmup initial bias lr | ||||
|   box: 0.05  # box loss gain | ||||
|   cls: 0.5  # cls loss gain | ||||
|   cls_pw: 1.0  # cls BCELoss positive_weight | ||||
|   obj: 1.0  # obj loss gain (scale with pixels) | ||||
|   obj_pw: 1.0  # obj BCELoss positive_weight | ||||
|   iou_t: 0.20  # IoU training threshold | ||||
|   anchor_t: 4.0  # anchor-multiple threshold | ||||
|   # anchors: 3  # anchors per output layer (0 to ignore) | ||||
|   fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5) | ||||
|   hsv_h: 0.015  # image HSV-Hue augmentation (fraction) | ||||
|   hsv_s: 0.7  # image HSV-Saturation augmentation (fraction) | ||||
|   hsv_v: 0.4  # image HSV-Value augmentation (fraction) | ||||
|   degrees: 0.0  # image rotation (+/- deg) | ||||
|   translate: 0.1  # image translation (+/- fraction) | ||||
|   scale: 0.5  # image scale (+/- gain) | ||||
|   shear: 0.0  # image shear (+/- deg) | ||||
|   perspective: 0.0  # image perspective (+/- fraction), range 0-0.001 | ||||
|   flipud: 0.0  # image flip up-down (probability) | ||||
|   fliplr: 0.5  # image flip left-right (probability) | ||||
|   mosaic: 1.0  # image mosaic (probability) | ||||
|   mixup: 0.0  # image mixup (probability) | ||||
|   copy_paste: 0.0  # segment copy-paste (probability) | ||||
| # Training options | ||||
| epochs: 300 | ||||
| batch_size: 16 | ||||
| img_size: 640 | ||||
| nosave: False | ||||
| cache: False # True/ram for ram, or disc | ||||
| device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu | ||||
| workers: 8 | ||||
| project: "ultralytics-yolo" | ||||
| name: "exp" # TODO: make this informative, maybe exp{#number}_{datetime} ? | ||||
| exist_ok: False | ||||
| pretrained: False | ||||
| optimizer: "Adam" # choices=['SGD', 'Adam', 'AdamW', 'RMSProp'] | ||||
| verbose: False | ||||
| seed: 0 | ||||
| local_rank: -1 | ||||
| #-----------------------------------# | ||||
|  | ||||
| # Hyper-parameters | ||||
| lr0: 0.001  # initial learning rate (SGD=1E-2, Adam=1E-3) | ||||
| lrf: 0.01  # final OneCycleLR learning rate (lr0 * lrf) | ||||
| momentum: 0.937  # SGD momentum/Adam beta1 | ||||
| weight_decay: 0.0005  # optimizer weight decay 5e-4 | ||||
| warmup_epochs: 3.0  # warmup epochs (fractions ok) | ||||
| warmup_momentum: 0.8  # warmup initial momentum | ||||
| warmup_bias_lr: 0.1  # warmup initial bias lr | ||||
| box: 0.05  # box loss gain | ||||
| cls: 0.5  # cls loss gain | ||||
| cls_pw: 1.0  # cls BCELoss positive_weight | ||||
| obj: 1.0  # obj loss gain (scale with pixels) | ||||
| obj_pw: 1.0  # obj BCELoss positive_weight | ||||
| iou_t: 0.20  # IoU training threshold | ||||
| anchor_t: 4.0  # anchor-multiple threshold | ||||
| # anchors: 3  # anchors per output layer (0 to ignore) | ||||
| fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5) | ||||
| hsv_h: 0.015  # image HSV-Hue augmentation (fraction) | ||||
| hsv_s: 0.7  # image HSV-Saturation augmentation (fraction) | ||||
| hsv_v: 0.4  # image HSV-Value augmentation (fraction) | ||||
| degrees: 0.0  # image rotation (+/- deg) | ||||
| translate: 0.1  # image translation (+/- fraction) | ||||
| scale: 0.5  # image scale (+/- gain) | ||||
| shear: 0.0  # image shear (+/- deg) | ||||
| perspective: 0.0  # image perspective (+/- fraction), range 0-0.001 | ||||
| flipud: 0.0  # image flip up-down (probability) | ||||
| fliplr: 0.5  # image flip left-right (probability) | ||||
| mosaic: 1.0  # image mosaic (probability) | ||||
| mixup: 0.0  # image mixup (probability) | ||||
| copy_paste: 0.0  # segment copy-paste (probability) | ||||
|  | ||||
| # Hydra configs ------------------------------------- | ||||
| # to disable hydra directory creation | ||||
| hydra: | ||||
|   output_subdir: null | ||||
|  | ||||
| @ -8,7 +8,8 @@ from ultralytics.yolo.utils import LOGGER | ||||
| from ultralytics.yolo.utils.anchors import check_anchor_order | ||||
| from ultralytics.yolo.utils.modeling import parse_model | ||||
| from ultralytics.yolo.utils.modeling.modules import * | ||||
| from ultralytics.yolo.utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, time_sync | ||||
| from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts, model_info, | ||||
|                                                 scale_img, time_sync) | ||||
|  | ||||
|  | ||||
| class BaseModel(nn.Module): | ||||
| @ -67,6 +68,10 @@ class BaseModel(nn.Module): | ||||
|                 m.anchor_grid = list(map(fn, m.anchor_grid)) | ||||
|         return self | ||||
|  | ||||
|     def load(self, weights): | ||||
|         # Force all tasks implement this function | ||||
|         raise NotImplementedError("This function needs to be implemented by derived classes!") | ||||
|  | ||||
|  | ||||
| class DetectionModel(BaseModel): | ||||
|     # YOLO detection model | ||||
| @ -166,6 +171,12 @@ class DetectionModel(BaseModel): | ||||
|             b.data[:, 5:5 + m.nc] += math.log(0.6 / (m.nc - 0.99999)) if cf is None else torch.log(cf / cf.sum())  # cls | ||||
|             mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) | ||||
|  | ||||
|     def load(self, weights): | ||||
|         ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak | ||||
|         csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32 | ||||
|         csd = intersect_state_dicts(csd, self.state_dict())  # intersect | ||||
|         self.load_state_dict(csd, strict=False)  # load | ||||
|  | ||||
|  | ||||
| class SegmentationModel(DetectionModel): | ||||
|     # YOLOv5 segmentation model | ||||
| @ -197,3 +208,9 @@ class ClassificationModel(BaseModel): | ||||
|     def _from_yaml(self, cfg): | ||||
|         # Create a YOLOv5 classification model from a *.yaml file | ||||
|         self.model = None | ||||
|  | ||||
|     def load(self, weights): | ||||
|         ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak | ||||
|         csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32 | ||||
|         csd = intersect_state_dicts(csd, self.state_dict())  # intersect | ||||
|         self.load_state_dict(csd, strict=False)  # load | ||||
|  | ||||
| @ -174,3 +174,8 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')): | ||||
|         return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn) | ||||
|  | ||||
|     return decorate | ||||
|  | ||||
|  | ||||
def intersect_state_dicts(da, db, exclude=()):
    """
    Return the entries of ``da`` that also exist in ``db`` with an identical shape.

    :param da: source mapping; its values are the ones returned
    :param db: reference mapping used for key and ``.shape`` comparison
    :param exclude: substrings; any key containing one of them is dropped
    :return: new dict in ``da`` iteration order
    """
    kept = {}
    for key, value in da.items():
        if key not in db:
            continue
        if any(fragment in key for fragment in exclude):
            continue
        if value.shape == db[key].shape:
            kept[key] = value
    return kept
|  | ||||
| @ -1,3 +1,4 @@ | ||||
| from ultralytics.yolo.v8.classify import train | ||||
| from ultralytics.yolo.v8.classify.train import ClassificationTrainer | ||||
| from ultralytics.yolo.v8.classify.val import ClassificationValidator | ||||
|  | ||||
| __all__ = ["train"] | ||||
|  | ||||
| @ -5,11 +5,10 @@ from pathlib import Path | ||||
| import hydra | ||||
| import torch | ||||
| import torchvision | ||||
| from val import ClassificationValidator | ||||
|  | ||||
| from ultralytics.yolo import BaseTrainer, v8 | ||||
| from ultralytics.yolo import v8 | ||||
| from ultralytics.yolo.data import build_classification_dataloader | ||||
| from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG | ||||
| from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG, BaseTrainer | ||||
| from ultralytics.yolo.utils.downloads import download | ||||
| from ultralytics.yolo.utils.files import WorkingDirectory | ||||
| from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first | ||||
| @ -18,9 +17,9 @@ from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zer | ||||
| # BaseTrainer python usage | ||||
| class ClassificationTrainer(BaseTrainer): | ||||
|  | ||||
|     def get_dataset(self): | ||||
|     def get_dataset(self, dataset): | ||||
|         # temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module | ||||
|         data = Path("datasets") / self.data | ||||
|         data = Path("datasets") / dataset | ||||
|         with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()): | ||||
|             data_dir = data if data.is_dir() else (Path.cwd() / data) | ||||
|             if not data_dir.is_dir(): | ||||
| @ -29,7 +28,7 @@ class ClassificationTrainer(BaseTrainer): | ||||
|                 if str(data) == 'imagenet': | ||||
|                     subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) | ||||
|                 else: | ||||
|                     url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{self.data}.zip' | ||||
|                     url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' | ||||
|                     download(url, dir=data_dir.parent) | ||||
|                 # TODO: add colorstr | ||||
|                 s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n" | ||||
| @ -39,17 +38,18 @@ class ClassificationTrainer(BaseTrainer): | ||||
|  | ||||
|         return train_set, test_set | ||||
|  | ||||
|     def get_dataloader(self, dataset, batch_size=None, rank=0): | ||||
|         return build_classification_dataloader(path=dataset, batch_size=self.train.batch_size, rank=rank) | ||||
|     def get_dataloader(self, dataset_path, batch_size=None, rank=0): | ||||
|         return build_classification_dataloader(path=dataset_path, batch_size=self.args.batch_size, rank=rank) | ||||
|  | ||||
|     def get_model(self): | ||||
|     def get_model(self, model, pretrained): | ||||
|         # temp. minimal. only supports torchvision models | ||||
|         if self.model in torchvision.models.__dict__:  # TorchVision models i.e. resnet50, efficientnet_b0 | ||||
|             model = torchvision.models.__dict__[self.model](weights='IMAGENET1K_V1' if self.train.pretrained else None) | ||||
|         model = self.args.model | ||||
|         if model in torchvision.models.__dict__:  # TorchVision models i.e. resnet50, efficientnet_b0 | ||||
|             model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None) | ||||
|         else: | ||||
|             raise ModuleNotFoundError(f'--model {self.model} not found.') | ||||
|             raise ModuleNotFoundError(f'--model {model} not found.') | ||||
|         for m in model.modules(): | ||||
|             if not self.train.pretrained and hasattr(m, 'reset_parameters'): | ||||
|             if not pretrained and hasattr(m, 'reset_parameters'): | ||||
|                 m.reset_parameters() | ||||
|         for p in model.parameters(): | ||||
|             p.requires_grad = True  # for training | ||||
| @ -57,7 +57,7 @@ class ClassificationTrainer(BaseTrainer): | ||||
|         return model | ||||
|  | ||||
|     def get_validator(self): | ||||
|         return ClassificationValidator(self.test_loader, self.device, logger=self.console)  # validator | ||||
|         return v8.classify.ClassificationValidator(self.test_loader, self.device, logger=self.console) | ||||
|  | ||||
|     def criterion(self, preds, targets): | ||||
|         return torch.nn.functional.cross_entropy(preds, targets) | ||||
| @ -66,17 +66,17 @@ class ClassificationTrainer(BaseTrainer): | ||||
| @hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0]) | ||||
| def train(cfg): | ||||
|     cfg.model = cfg.model or "squeezenet1_0" | ||||
|     cfg.data = cfg.data or "imagenette160"  # or yolo.ClassificationDataset("mnist") | ||||
|     cfg.data = cfg.data or "imagenette"  # or yolo.ClassificationDataset("mnist") | ||||
|     trainer = ClassificationTrainer(cfg) | ||||
|     trainer.run() | ||||
|     trainer.train() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     """ | ||||
|     CLI usage: | ||||
|     python ../path/to/train.py train.epochs=10 train.project="name" hyps.lr0=0.1 | ||||
|     python ../path/to/train.py args.epochs=10 args.project="name" hyps.lr0=0.1 | ||||
|  | ||||
|     TODO: | ||||
|     Direct cli support, i.e, yolov8 classify_train train.epochs 10 | ||||
|     Direct cli support, i.e, yolov8 classify_train args.epochs 10 | ||||
|     """ | ||||
|     train() | ||||
|  | ||||
| @ -1,9 +1,9 @@ | ||||
| import torch | ||||
|  | ||||
| from ultralytics import yolo | ||||
| from ultralytics.yolo.engine.validator import BaseValidator | ||||
|  | ||||
|  | ||||
| class ClassificationValidator(yolo.BaseValidator): | ||||
| class ClassificationValidator(BaseValidator): | ||||
|  | ||||
|     def init_metrics(self): | ||||
|         self.correct = torch.tensor([]) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user