From db1031a1a9f6117dd40df98e533dea2ccebaf71c Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 17 Nov 2022 10:35:53 +0530 Subject: [PATCH] Allow setting model attributes before training (#45) --- ultralytics/yolo/engine/trainer.py | 20 +++++++------------- ultralytics/yolo/v8/segment/train.py | 10 ++++++++++ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index c48da2b..ec82738 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -133,6 +133,7 @@ class BaseTrainer: """ Builds dataloaders and optimizer on correct rank process """ + self.set_model_attributes() self.optimizer = build_optimizer(model=self.model, name=self.args.optimizer, lr=self.args.lr0, @@ -146,19 +147,6 @@ class BaseTrainer: print("created testloader :", rank) self.console.info(self.progress_string()) - def _set_model_attributes(self): - # TODO: fix and use after self.data_dict is available - ''' - head = utils.torch_utils.de_parallel(self.model).model[-1] - self.args.box *= 3 / head.nl # scale to layers - self.args.cls *= head.nc / 80 * 3 / head.nl # scale to classes and layers - self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl # scale to image size and layers - model.nc = nc # attach number of classes to model - model.hyp = hyp # attach hyperparameters to model - model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights - model.names = names - ''' - def _do_train(self, rank, world_size): if world_size > 1: self._setup_ddp(rank, world_size) @@ -302,6 +290,12 @@ class BaseTrainer: if not self.best_fitness or self.best_fitness < self.fitness: self.best_fitness = self.fitness + def set_model_attributes(self): + """ + To set or update model parameters before training. + """ + pass + def build_targets(self, preds, targets): pass diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py index 1dd64c8..10c3522 100644 --- a/ultralytics/yolo/v8/segment/train.py +++ b/ultralytics/yolo/v8/segment/train.py @@ -54,6 +54,16 @@ class SegmentationTrainer(BaseTrainer): model.load(weights) return model + def set_model_attributes(self): + nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps) + self.args.box *= 3 / nl # scale to layers + self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers + self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl # scale to image size and layers + self.model.nc = self.data["nc"] # attach number of classes to model + self.model.args = self.args # attach hyperparameters to model + # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc + self.model.names = self.data["names"] + def get_validator(self): return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)