Allow setting model attributes before training (#45)

single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent 832ea56eb4
commit db1031a1a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -133,6 +133,7 @@ class BaseTrainer:
""" """
Builds dataloaders and optimizer on correct rank process Builds dataloaders and optimizer on correct rank process
""" """
self.set_model_attributes()
self.optimizer = build_optimizer(model=self.model, self.optimizer = build_optimizer(model=self.model,
name=self.args.optimizer, name=self.args.optimizer,
lr=self.args.lr0, lr=self.args.lr0,
@ -146,19 +147,6 @@ class BaseTrainer:
print("created testloader :", rank) print("created testloader :", rank)
self.console.info(self.progress_string()) self.console.info(self.progress_string())
def _set_model_attributes(self):
# TODO: fix and use after self.data_dict is available
'''
head = utils.torch_utils.de_parallel(self.model).model[-1]
self.args.box *= 3 / head.nl # scale to layers
self.args.cls *= head.nc / 80 * 3 / head.nl # scale to classes and layers
self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl # scale to image size and layers
model.nc = nc # attach number of classes to model
model.hyp = hyp # attach hyperparameters to model
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
model.names = names
'''
def _do_train(self, rank, world_size): def _do_train(self, rank, world_size):
if world_size > 1: if world_size > 1:
self._setup_ddp(rank, world_size) self._setup_ddp(rank, world_size)
@ -302,6 +290,12 @@ class BaseTrainer:
if not self.best_fitness or self.best_fitness < self.fitness: if not self.best_fitness or self.best_fitness < self.fitness:
self.best_fitness = self.fitness self.best_fitness = self.fitness
def set_model_attributes(self):
"""
To set or update model parameters before training.
"""
pass
def build_targets(self, preds, targets): def build_targets(self, preds, targets):
pass pass

@ -54,6 +54,16 @@ class SegmentationTrainer(BaseTrainer):
model.load(weights) model.load(weights)
return model return model
def set_model_attributes(self):
nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)
self.args.box *= 3 / nl # scale to layers
self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl # scale to image size and layers
self.model.nc = self.data["nc"] # attach number of classes to model
self.model.args = self.args # attach hyperparameters to model
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
self.model.names = self.data["names"]
def get_validator(self): def get_validator(self):
return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console) return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)

Loading…
Cancel
Save