Smart Model loading (#31)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2022-11-01 04:22:12 +05:30
committed by GitHub
parent 1054819a59
commit 92c60758dd
4 changed files with 80 additions and 42 deletions

View File

@ -41,21 +41,6 @@ class ClassificationTrainer(BaseTrainer):
def get_dataloader(self, dataset_path, batch_size=None, rank=0):
return build_classification_dataloader(path=dataset_path, batch_size=self.args.batch_size, rank=rank)
def get_model(self, model, pretrained):
# temp. minimal. only supports torchvision models
model = self.args.model
if model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0
model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else:
raise ModuleNotFoundError(f'--model {model} not found.')
for m in model.modules():
if not pretrained and hasattr(m, 'reset_parameters'):
m.reset_parameters()
for p in model.parameters():
p.requires_grad = True # for training
return model
def get_validator(self):
return v8.classify.ClassificationValidator(self.test_loader, self.device, logger=self.console)
@ -65,8 +50,8 @@ class ClassificationTrainer(BaseTrainer):
@hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0])
def train(cfg):
cfg.model = cfg.model or "squeezenet1_0"
cfg.data = cfg.data or "imagenette" # or yolo.ClassificationDataset("mnist")
cfg.model = cfg.model or "resnet18"
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
trainer = ClassificationTrainer(cfg)
trainer.train()