Fix SAM autodownload to `/weights` (#3655)

single_channel
Glenn Jocher 1 year ago committed by GitHub
parent 82920ef7ec
commit 48d7dbdbf9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -100,7 +100,7 @@ def _build_sam(
) )
sam.eval() sam.eval()
if checkpoint is not None: if checkpoint is not None:
attempt_download_asset(checkpoint) checkpoint = attempt_download_asset(checkpoint)
with open(checkpoint, 'rb') as f: with open(checkpoint, 'rb') as f:
state_dict = torch.load(f) state_dict = torch.load(f)
sam.load_state_dict(state_dict) sam.load_state_dict(state_dict)

Loading…
Cancel
Save