update model initialization design, supports custom data/num_classes (#44)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -63,10 +63,8 @@ class BaseTrainer:
|
||||
else:
|
||||
self.data = check_dataset(self.data)
|
||||
self.trainset, self.testset = self.get_dataset(self.data)
|
||||
if self.args.cfg is not None:
|
||||
self.model = self.load_cfg(check_file(self.args.cfg))
|
||||
if self.args.model is not None:
|
||||
self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)
|
||||
if self.args.model:
|
||||
self.model = self.get_model(self.args.model, self.data)
|
||||
|
||||
# epoch level metrics
|
||||
self.metrics = {} # handle metrics returned by validator
|
||||
@ -261,20 +259,20 @@ class BaseTrainer:
|
||||
"""
|
||||
return data["train"], data["val"]
|
||||
|
||||
def get_model(self, model, pretrained):
|
||||
def get_model(self, model: str, data: Dict):
|
||||
"""
|
||||
load/create/download model for any task
|
||||
"""
|
||||
model = get_model(model)
|
||||
for m in model.modules():
|
||||
if not pretrained and hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
for p in model.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
pretrained = False
|
||||
if not str(model).endswith(".yaml"):
|
||||
pretrained = True
|
||||
weights = get_model(model) # rename this to something less confusing?
|
||||
model = self.load_model(model_cfg=model if not pretrained else None,
|
||||
weights=weights if pretrained else None,
|
||||
data=self.data)
|
||||
return model
|
||||
|
||||
def load_cfg(self, cfg):
|
||||
def load_model(self, model_cfg, weights, data):
|
||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||
|
||||
def get_validator(self):
|
||||
|
Reference in New Issue
Block a user