Smart Model loading (#31)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -41,21 +41,6 @@ class ClassificationTrainer(BaseTrainer):
|
||||
def get_dataloader(self, dataset_path, batch_size=None, rank=0):
|
||||
return build_classification_dataloader(path=dataset_path, batch_size=self.args.batch_size, rank=rank)
|
||||
|
||||
def get_model(self, model, pretrained):
|
||||
# temp. minimal. only supports torchvision models
|
||||
model = self.args.model
|
||||
if model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0
|
||||
model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
|
||||
else:
|
||||
raise ModuleNotFoundError(f'--model {model} not found.')
|
||||
for m in model.modules():
|
||||
if not pretrained and hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
for p in model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
return v8.classify.ClassificationValidator(self.test_loader, self.device, logger=self.console)
|
||||
|
||||
@ -65,8 +50,8 @@ class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
@hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0])
|
||||
def train(cfg):
|
||||
cfg.model = cfg.model or "squeezenet1_0"
|
||||
cfg.data = cfg.data or "imagenette" # or yolo.ClassificationDataset("mnist")
|
||||
cfg.model = cfg.model or "resnet18"
|
||||
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
|
||||
trainer = ClassificationTrainer(cfg)
|
||||
trainer.train()
|
||||
|
||||
|
Reference in New Issue
Block a user