General console printout updates (#48)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -25,8 +25,8 @@ from tqdm import tqdm
|
||||
import ultralytics.yolo.utils as utils
|
||||
import ultralytics.yolo.utils.loggers as loggers
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT
|
||||
from ultralytics.yolo.utils.checks import check_file, check_yaml
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
|
||||
from ultralytics.yolo.utils.checks import print_args
|
||||
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
||||
from ultralytics.yolo.utils.modeling import get_model
|
||||
|
||||
@ -41,19 +41,17 @@ class BaseTrainer:
|
||||
self.validator = None
|
||||
self.model = None
|
||||
self.callbacks = defaultdict(list)
|
||||
self.console.info(f"Training config: \n args: \n {self.args}") # to debug
|
||||
# Directories
|
||||
self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
|
||||
self.wdir = self.save_dir / 'weights'
|
||||
self.wdir = self.save_dir / 'weights' # weights dir
|
||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'
|
||||
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
|
||||
print_args(dict(self.args))
|
||||
|
||||
# Save run settings
|
||||
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
|
||||
|
||||
# device
|
||||
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size)
|
||||
self.console.info(f"running on device {self.device}")
|
||||
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
|
||||
|
||||
# Model and Dataloaders.
|
||||
@ -64,7 +62,7 @@ class BaseTrainer:
|
||||
self.data = check_dataset(self.data)
|
||||
self.trainset, self.testset = self.get_dataset(self.data)
|
||||
if self.args.model:
|
||||
self.model = self.get_model(self.args.model, self.data)
|
||||
self.model = self.get_model(self.args.model)
|
||||
|
||||
# epoch level metrics
|
||||
self.metrics = {} # handle metrics returned by validator
|
||||
@ -115,7 +113,7 @@ class BaseTrainer:
|
||||
if world_size > 1:
|
||||
mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
|
||||
else:
|
||||
self._do_train(-1, 1)
|
||||
self._do_train()
|
||||
|
||||
def _setup_ddp(self, rank, world_size):
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
@ -147,7 +145,7 @@ class BaseTrainer:
|
||||
print("created testloader :", rank)
|
||||
self.console.info(self.progress_string())
|
||||
|
||||
def _do_train(self, rank, world_size):
|
||||
def _do_train(self, rank=-1, world_size=1):
|
||||
if world_size > 1:
|
||||
self._setup_ddp(rank, world_size)
|
||||
else:
|
||||
@ -165,9 +163,7 @@ class BaseTrainer:
|
||||
self.model.train()
|
||||
pbar = enumerate(self.train_loader)
|
||||
if rank in {-1, 0}:
|
||||
pbar = tqdm(enumerate(self.train_loader),
|
||||
total=len(self.train_loader),
|
||||
bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
|
||||
pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT)
|
||||
tloss = None
|
||||
for i, batch in pbar:
|
||||
# img, label (classification)/ img, targets, paths, _, masks(detection)
|
||||
@ -249,18 +245,14 @@ class BaseTrainer:
|
||||
"""
|
||||
return data["train"], data["val"]
|
||||
|
||||
def get_model(self, model: str, data: Dict):
|
||||
def get_model(self, model: Union[str, Path]):
|
||||
"""
|
||||
load/create/download model for any task
|
||||
"""
|
||||
pretrained = False
|
||||
if not str(model).endswith(".yaml"):
|
||||
pretrained = True
|
||||
weights = get_model(model) # rename this to something less confusing?
|
||||
model = self.load_model(model_cfg=model if not pretrained else None,
|
||||
weights=weights if pretrained else None,
|
||||
data=self.data)
|
||||
return model
|
||||
pretrained = not str(model).endswith(".yaml")
|
||||
return self.load_model(model_cfg=None if pretrained else model,
|
||||
weights=get_model(model) if pretrained else None,
|
||||
data=self.data) # model
|
||||
|
||||
def load_model(self, model_cfg, weights, data):
|
||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||
|
@ -5,6 +5,7 @@ from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import TQDM_BAR_FORMAT
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device
|
||||
|
||||
@ -49,7 +50,7 @@ class BaseValidator:
|
||||
loss = 0
|
||||
n_batches = len(self.dataloader)
|
||||
desc = self.get_desc()
|
||||
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
|
||||
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format=TQDM_BAR_FORMAT)
|
||||
self.init_metrics(de_parallel(model))
|
||||
with torch.no_grad():
|
||||
for batch_i, batch in enumerate(bar):
|
||||
|
Reference in New Issue
Block a user