update model initialization design, supports custom data/num_classes (#44)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2022-11-15 20:06:29 +05:30
committed by GitHub
parent 1f3aad86c1
commit 832ea56eb4
8 changed files with 67 additions and 44 deletions

View File

@ -63,10 +63,8 @@ class BaseTrainer:
else:
self.data = check_dataset(self.data)
self.trainset, self.testset = self.get_dataset(self.data)
if self.args.cfg is not None:
self.model = self.load_cfg(check_file(self.args.cfg))
if self.args.model is not None:
self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)
if self.args.model:
self.model = self.get_model(self.args.model, self.data)
# epoch level metrics
self.metrics = {} # handle metrics returned by validator
@ -261,20 +259,20 @@ class BaseTrainer:
"""
return data["train"], data["val"]
def get_model(self, model, pretrained):
def get_model(self, model: str, data: Dict):
"""
load/create/download model for any task
"""
model = get_model(model)
for m in model.modules():
if not pretrained and hasattr(m, 'reset_parameters'):
m.reset_parameters()
for p in model.parameters():
p.requires_grad = True
pretrained = False
if not str(model).endswith(".yaml"):
pretrained = True
weights = get_model(model) # rename this to something less confusing?
model = self.load_model(model_cfg=model if not pretrained else None,
weights=weights if pretrained else None,
data=self.data)
return model
def load_cfg(self, cfg):
def load_model(self, model_cfg, weights, data):
raise NotImplementedError("This task trainer doesn't support loading cfg files")
def get_validator(self):