Model builder (#29)

Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2022-10-19 02:44:23 +05:30
committed by GitHub
parent c5cb76b356
commit 7b560f7861
27 changed files with 2622 additions and 407 deletions

View File

@ -20,7 +20,8 @@ from tqdm import tqdm
import ultralytics.yolo.utils as utils
import ultralytics.yolo.utils.loggers as loggers
from ultralytics.yolo.utils.general import LOGGER, ROOT
from ultralytics.yolo.utils import LOGGER, ROOT
from ultralytics.yolo.utils.files import increment_path, save_yaml
CONFIG_PATH_ABS = ROOT / "yolo/utils/configs"
DEFAULT_CONFIG = "defaults.yaml"
@ -35,16 +36,16 @@ class BaseTrainer:
self.callbacks = defaultdict(list)
self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}") # to debug
# Directories
self.save_dir = utils.increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
self.save_dir = increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
self.wdir = self.save_dir / 'weights'
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'
# Save run settings
utils.save_yaml(self.save_dir / 'train.yaml', OmegaConf.to_container(self.train, resolve=True))
save_yaml(self.save_dir / 'train.yaml', OmegaConf.to_container(self.train, resolve=True))
# device
self.device = utils.select_device(self.train.device, self.train.batch_size)
self.device = utils.torch_utils.select_device(self.train.device, self.train.batch_size)
self.console.info(f"running on device {self.device}")
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')

View File

@ -3,7 +3,8 @@ import logging
import torch
from tqdm import tqdm
from ultralytics.yolo.utils import Profile, select_device
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import select_device
class BaseValidator: