|
|
@ -47,7 +47,7 @@ class ClassificationTrainer(BaseTrainer):
|
|
|
|
return torch.nn.functional.cross_entropy(preds, targets)
|
|
|
|
return torch.nn.functional.cross_entropy(preds, targets)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.stem)
|
|
|
|
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
|
|
|
def train(cfg):
|
|
|
|
def train(cfg):
|
|
|
|
cfg.model = cfg.model or "resnet18"
|
|
|
|
cfg.model = cfg.model or "resnet18"
|
|
|
|
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
|
|
|
|
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
|
|
|
|