@@ -26,8 +26,7 @@ import ultralytics.yolo.utils.callbacks as callbacks
 from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
 from ultralytics.yolo.utils.checks import check_file, print_args
-from ultralytics.yolo.utils.files import increment_path, save_yaml
-from ultralytics.yolo.utils.modeling import get_model
+from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml
 from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
 
 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@@ -38,6 +37,7 @@ class BaseTrainer:
     def __init__(self, config=DEFAULT_CONFIG, overrides={}):
         self.args = self._get_config(config, overrides)
+        self.check_resume()
         init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
 
         self.console = LOGGER
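
Note: the new check_resume() call runs before seeding so that a resumed run first swaps in the original run's saved arguments (see check_resume further down in this diff). When resume=True is passed without an explicit checkpoint path, the get_latest_run helper imported above supplies one. A minimal sketch of the behaviour this diff assumes; the actual body lives in ultralytics/yolo/utils/files.py and may differ:

    import glob
    import os

    def get_latest_run(search_dir='.'):
        # newest 'last*.pt' checkpoint anywhere under search_dir wins
        last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
        return max(last_list, key=os.path.getctime) if last_list else ''
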
@@ -50,6 +50,7 @@
         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
         self.batch_size = self.args.batch_size
         self.epochs = self.args.epochs
+        self.start_epoch = 0
         print_args(dict(self.args))
 
         # Save run settings
@@ -66,8 +67,6 @@ class BaseTrainer:
         else:
             self.data = check_dataset(self.data)
         self.trainset, self.testset = self.get_dataset(self.data)
-        if self.args.model:
-            self.model = self.get_model(self.args.model)
         self.ema = None
 
         # Optimization utils init
@@ -136,15 +135,17 @@ class BaseTrainer:
         self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
 
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
-        self.model = self.model.to(self.device)
-        self.model = DDP(self.model, device_ids=[rank])
 
     def _setup_train(self, rank, world_size):
         """
        Builds dataloaders and optimizer on correct rank process
         """
-        # Optimizer
+        # model
+        ckpt = self.setup_model()
+        self.set_model_attributes()
+        if world_size > 1:
+            self.model = DDP(self.model, device_ids=[rank])
+        # Optimizer
         self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
         self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
         self.optimizer = build_optimizer(model=self.model,
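
The two scaling lines above keep the effective update equivalent to the nominal batch size nbs: gradients accumulate over `accumulate` batches before an optimizer step, and weight decay is rescaled to match. A worked example, assuming nbs=64 and weight_decay=0.0005 (defaults assumed, not shown in this hunk):

    nbs, batch_size, weight_decay = 64, 24, 0.0005  # assumed defaults, batch size 24
    accumulate = max(round(nbs / batch_size), 1)    # round(64 / 24) = 3 batches per step
    weight_decay *= batch_size * accumulate / nbs   # 0.0005 * 72/64 = 0.0005625
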
@@ -158,6 +159,8 @@ class BaseTrainer:
         else:
             self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
         self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+        self.resume_training(ckpt)
+        self.scheduler.last_epoch = self.start_epoch - 1  # do not move
 
         # dataloaders
         batch_size = self.batch_size // world_size
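
Ordering matters here: resume_training(ckpt) sets self.start_epoch, and rewinding scheduler.last_epoch to start_epoch - 1 makes the first scheduler.step() yield the learning rate for the resumed epoch rather than epoch 0 (hence the "do not move" comment). A self-contained sketch with illustrative values (epochs=100, lrf=0.01, lr0=0.01 are assumptions):

    import torch
    from torch.optim import SGD, lr_scheduler

    epochs, lrf = 100, 0.01
    lf = lambda x: (1 - x / epochs) * (1.0 - lrf) + lrf  # same linear lf as above
    opt = SGD([torch.zeros(3, requires_grad=True)], lr=0.01)
    sched = lr_scheduler.LambdaLR(opt, lr_lambda=lf)
    sched.last_epoch = 50 - 1   # as if resume_training() set start_epoch = 50
    sched.step()                # LR is now lr0 * lf(50) = 0.00505, not lr0 * lf(0)
    print(sched.get_last_lr())  # ~[0.00505]
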
@@ -174,20 +177,18 @@ class BaseTrainer:
     def _do_train(self, rank=-1, world_size=1):
         if world_size > 1:
             self._setup_ddp(rank, world_size)
-        else:
-            self.model = self.model.to(self.device)
 
-        self.trigger_callbacks("before_train")
         self._setup_train(rank, world_size)
+        self.trigger_callbacks("before_train")
 
-        self.epoch = 0
         self.epoch_time = None
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
         last_opt_step = -1
-        for epoch in range(self.epochs):
+        for epoch in range(self.start_epoch, self.epochs):
+            self.epoch = epoch
             self.trigger_callbacks("on_epoch_start")
             self.model.train()
             if rank != -1:
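
With the loop now starting at self.start_epoch, warmup is still computed from the full loader length. A quick worked example for nw, assuming the default warmup_epochs=3 (an assumption) and a 120-batch loader:

    warmup_epochs, nb = 3, 120                # assumed default; example loader length
    nw = max(round(warmup_epochs * nb), 100)  # 360 warmup iterations, floored at 100
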
@@ -257,11 +258,10 @@ class BaseTrainer:
                 self.save_metrics(metrics=log_vals)
 
                 # save model
-                if (not self.args.nosave) or (self.epoch + 1 == self.epochs):
+                if (not self.args.nosave) or (epoch + 1 == self.epochs):
                     self.save_model()
                     self.trigger_callbacks('on_model_save')
 
-            self.epoch += 1
             tnow = time.time()
             self.epoch_time = tnow - self.epoch_time_start
             self.epoch_time_start = tnow
@@ -301,17 +301,21 @@ class BaseTrainer:
         """
         return data["train"], data.get("val") or data.get("test")
 
-    def get_model(self, model: Union[str, Path]):
+    def setup_model(self):
         """
        load/create/download model for any task
         """
-        pretrained = True
-        if str(model).endswith(".yaml"):
+        model = self.args.model
+        pretrained = not (str(model).endswith(".yaml"))
         # config
+        if not pretrained:
             model = check_file(model)
-            pretrained = False
-        return self.load_model(model_cfg=None if pretrained else model,
-                               weights=get_model(model) if pretrained else None,
-                               data=self.data)  # model
+        ckpt = self.load_ckpt(model) if pretrained else None
+        self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt).to(self.device)  # model
+        return ckpt
+
+    def load_ckpt(self, ckpt):
+        return torch.load(ckpt, map_location='cpu')
 
     def optimizer_step(self):
         self.scaler.unscale_(self.optimizer)  # unscale gradients
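
setup_model() now returns the raw checkpoint dict (or None) so that _setup_train() can hand it to resume_training(). A hedged sketch of the keys this diff reads from that dict; it is written by save_model(), which is outside this hunk, so treat the exact layout as an assumption:

    import torch

    ckpt = torch.load('last.pt', map_location='cpu')  # hypothetical path
    ckpt['epoch']         # last completed epoch; start_epoch = ckpt['epoch'] + 1
    ckpt['best_fitness']  # best validation fitness, restored with the optimizer
    ckpt['ema']           # EMA model, restored via .float().state_dict()
    ckpt['updates']       # EMA update counter
    ckpt['optimizer']     # optimizer.state_dict()
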
@@ -350,7 +354,7 @@ class BaseTrainer:
         if rank in {-1, 0}:
             self.console.info(text)
 
-    def load_model(self, model_cfg, weights, data):
+    def load_model(self, model_cfg, weights):
         raise NotImplementedError("This task trainer doesn't support loading cfg files")
 
     def get_validator(self):
@@ -409,6 +413,40 @@ class BaseTrainer:
         if f is self.best:
             self.console.info(f'\nValidating {f}...')
 
+    def check_resume(self):
+        resume = self.args.resume
+        if resume:
+            last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run())
+            args_yaml = last.parent.parent / 'args.yaml'  # train options yaml
+            if args_yaml.is_file():
+                args = self._get_config(args_yaml)  # replace
+                args.model, args.resume, args.exist_ok = str(last), True, True  # reinstate
+                self.args = args
+
+    def resume_training(self, ckpt):
+        if ckpt is None:
+            return
+        best_fitness = 0.0
+        start_epoch = ckpt['epoch'] + 1
+        if ckpt['optimizer'] is not None:
+            self.optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
+            best_fitness = ckpt['best_fitness']
+        if self.ema and ckpt.get('ema'):
+            self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
+            self.ema.updates = ckpt['updates']
+        if self.args.resume:
+            assert start_epoch > 0, f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
+                                    f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
+            LOGGER.info(
+                f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')
+        if self.epochs < start_epoch:
+            LOGGER.info(
+                f"{self.args.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
+            )
+            self.epochs += ckpt['epoch']  # finetune additional epochs
+        self.best_fitness = best_fitness
+        self.start_epoch = start_epoch
+
 
 def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
     # TODO: 1. docstring with example? 2. Move this inside Trainer? or utils?
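
On the TODO above: a docstring example could show the usual YOLO-style parameter grouping, where biases and BatchNorm weights are exempt from weight decay. The sketch below is one plausible body under that assumption, not the repo's actual implementation:

    import torch.nn as nn
    from torch.optim import SGD, Adam

    def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
        g_decay, g_no_decay, g_bias = [], [], []
        for m in model.modules():
            if hasattr(m, 'bias') and isinstance(m.bias, nn.Parameter):
                g_bias.append(m.bias)        # biases: never decayed
            if isinstance(m, nn.BatchNorm2d):
                g_no_decay.append(m.weight)  # BN weights: no decay
            elif hasattr(m, 'weight') and isinstance(m.weight, nn.Parameter):
                g_decay.append(m.weight)     # conv/linear weights: decayed
        if name == 'Adam':
            optimizer = Adam(g_bias, lr=lr, betas=(momentum, 0.999))
        else:
            optimizer = SGD(g_bias, lr=lr, momentum=momentum, nesterov=True)
        optimizer.add_param_group({'params': g_decay, 'weight_decay': decay})
        optimizer.add_param_group({'params': g_no_decay, 'weight_decay': 0.0})
        return optimizer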