add resuming (#63)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Laughing 2 years ago committed by GitHub
parent de3e6ca54d
commit fbeeb5d1e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -26,8 +26,7 @@ import ultralytics.yolo.utils.callbacks as callbacks
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
from ultralytics.yolo.utils.checks import check_file, print_args
from ultralytics.yolo.utils.files import increment_path, save_yaml
from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.files import get_latest_run, increment_path, save_yaml
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@ -38,6 +37,7 @@ class BaseTrainer:
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
self.args = self._get_config(config, overrides)
self.check_resume()
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
self.console = LOGGER
@ -50,6 +50,7 @@ class BaseTrainer:
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
self.batch_size = self.args.batch_size
self.epochs = self.args.epochs
self.start_epoch = 0
print_args(dict(self.args))
# Save run settings
@ -66,8 +67,6 @@ class BaseTrainer:
else:
self.data = check_dataset(self.data)
self.trainset, self.testset = self.get_dataset(self.data)
if self.args.model:
self.model = self.get_model(self.args.model)
self.ema = None
# Optimization utils init
@ -136,15 +135,17 @@ class BaseTrainer:
self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
self.model = self.model.to(self.device)
self.model = DDP(self.model, device_ids=[rank])
def _setup_train(self, rank, world_size):
"""
Builds dataloaders and optimizer on correct rank process
"""
# Optimizer
# model
ckpt = self.setup_model()
self.set_model_attributes()
if world_size > 1:
self.model = DDP(self.model, device_ids=[rank])
# Optimizer
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
self.optimizer = build_optimizer(model=self.model,
@ -158,6 +159,8 @@ class BaseTrainer:
else:
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
self.resume_training(ckpt)
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
# dataloaders
batch_size = self.batch_size // world_size
@ -174,20 +177,18 @@ class BaseTrainer:
def _do_train(self, rank=-1, world_size=1):
if world_size > 1:
self._setup_ddp(rank, world_size)
else:
self.model = self.model.to(self.device)
self.trigger_callbacks("before_train")
self._setup_train(rank, world_size)
self.trigger_callbacks("before_train")
self.epoch = 0
self.epoch_time = None
self.epoch_time_start = time.time()
self.train_time_start = time.time()
nb = len(self.train_loader) # number of batches
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
last_opt_step = -1
for epoch in range(self.epochs):
for epoch in range(self.start_epoch, self.epochs):
self.epoch = epoch
self.trigger_callbacks("on_epoch_start")
self.model.train()
if rank != -1:
@ -257,11 +258,10 @@ class BaseTrainer:
self.save_metrics(metrics=log_vals)
# save model
if (not self.args.nosave) or (self.epoch + 1 == self.epochs):
if (not self.args.nosave) or (epoch + 1 == self.epochs):
self.save_model()
self.trigger_callbacks('on_model_save')
self.epoch += 1
tnow = time.time()
self.epoch_time = tnow - self.epoch_time_start
self.epoch_time_start = tnow
@ -301,17 +301,21 @@ class BaseTrainer:
"""
return data["train"], data.get("val") or data.get("test")
def get_model(self, model: Union[str, Path]):
def setup_model(self):
"""
load/create/download model for any task
"""
pretrained = True
if str(model).endswith(".yaml"):
model = self.args.model
pretrained = not (str(model).endswith(".yaml"))
# config
if not pretrained:
model = check_file(model)
pretrained = False
return self.load_model(model_cfg=None if pretrained else model,
weights=get_model(model) if pretrained else None,
data=self.data) # model
ckpt = self.load_ckpt(model) if pretrained else None
self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt).to(self.device) # model
return ckpt
def load_ckpt(self, ckpt):
return torch.load(ckpt, map_location='cpu')
def optimizer_step(self):
self.scaler.unscale_(self.optimizer) # unscale gradients
@ -350,7 +354,7 @@ class BaseTrainer:
if rank in {-1, 0}:
self.console.info(text)
def load_model(self, model_cfg, weights, data):
def load_model(self, model_cfg, weights):
raise NotImplementedError("This task trainer doesn't support loading cfg files")
def get_validator(self):
@ -409,6 +413,40 @@ class BaseTrainer:
if f is self.best:
self.console.info(f'\nValidating {f}...')
def check_resume(self):
resume = self.args.resume
if resume:
last = Path(check_file(resume) if isinstance(resume, str) else get_latest_run())
args_yaml = last.parent.parent / 'args.yaml' # train options yaml
if args_yaml.is_file():
args = self._get_config(args_yaml) # replace
args.model, args.resume, args.exist_ok = str(last), True, True # reinstate
self.args = args
def resume_training(self, ckpt):
if ckpt is None:
return
best_fitness = 0.0
start_epoch = ckpt['epoch'] + 1
if ckpt['optimizer'] is not None:
self.optimizer.load_state_dict(ckpt['optimizer']) # optimizer
best_fitness = ckpt['best_fitness']
if self.ema and ckpt.get('ema'):
self.ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
self.ema.updates = ckpt['updates']
if self.args.resume:
assert start_epoch > 0, f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
LOGGER.info(
f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')
if self.epochs < start_epoch:
LOGGER.info(
f"{self.args.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
)
self.epochs += ckpt['epoch'] # finetune additional epochs
self.best_fitness = best_fitness
self.start_epoch = start_epoch
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
# TODO: 1. docstring with example? 2. Move this inside Trainer? or utils?

@ -33,6 +33,7 @@ overlap_mask: True # masks overlap
mask_ratio: 4 # mask downsample ratio
# Classification
dropout: False # use dropout
resume: False
# Val/Test settings ----------------------------------------------------------------------------------------------------

@ -1,4 +1,5 @@
import contextlib
import glob
import os
from datetime import datetime
from pathlib import Path
@ -74,3 +75,9 @@ def file_date(path=__file__):
# Return human-readable file modification date, i.e. '2021-3-26'
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
return f'{t.year}-{t.month}-{t.day}'
def get_latest_run(search_dir='.'):
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
return max(last_list, key=os.path.getctime) if last_list else ''

@ -4,6 +4,7 @@ import torch
from ultralytics.yolo import v8
from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.modeling.tasks import ClassificationModel
@ -12,13 +13,13 @@ class ClassificationTrainer(BaseTrainer):
def set_model_attributes(self):
self.model.names = self.data["names"]
def load_model(self, model_cfg, weights, data):
def load_model(self, model_cfg, weights):
# TODO: why treat clf models as unique. We should have clf yamls?
if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision
model = weights
else:
model = ClassificationModel(model_cfg, weights, data["nc"])
ClassificationModel.reshape_outputs(model, data["nc"])
model = ClassificationModel(model_cfg, weights, self.data["nc"])
ClassificationModel.reshape_outputs(model, self.data["nc"])
for m in model.modules():
if not weights and hasattr(m, 'reset_parameters'):
m.reset_parameters()
@ -28,6 +29,9 @@ class ClassificationTrainer(BaseTrainer):
p.requires_grad = True # for training
return model
def load_ckpt(self, ckpt):
return get_model(ckpt)
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
return build_classification_dataloader(path=dataset_path,
imgsz=self.args.img_size,
@ -46,6 +50,12 @@ class ClassificationTrainer(BaseTrainer):
loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
return loss, loss
def check_resume(self):
pass
def resume_training(self, ckpt):
pass
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
def train(cfg):

@ -15,10 +15,10 @@ from .val import DetectionValidator
# BaseTrainer python usage
class DetectionTrainer(SegmentationTrainer):
def load_model(self, model_cfg, weights, data):
def load_model(self, model_cfg, weights):
model = DetectionModel(model_cfg or weights["model"].yaml,
ch=3,
nc=data["nc"],
nc=self.data["nc"],
anchors=self.args.get("anchors"))
if weights:
model.load(weights)

@ -26,10 +26,10 @@ class SegmentationTrainer(BaseTrainer):
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
return batch
def load_model(self, model_cfg, weights, data):
def load_model(self, model_cfg, weights):
model = SegmentationModel(model_cfg or weights["model"].yaml,
ch=3,
nc=data["nc"],
nc=self.data["nc"],
anchors=self.args.get("anchors"))
if weights:
model.load(weights)

@ -242,7 +242,7 @@ class SegmentationValidator(BaseValidator):
cls,
bboxes,
masks,
paths,
paths=paths,
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names)

Loading…
Cancel
Save