Add `device=cuda` support (#3133)

single_channel
Glenn Jocher 1 year ago committed by GitHub
parent ff91fbd9c3
commit 8e60fc9276
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -64,6 +64,8 @@ def select_device(device='', batch=0, newline=False, verbose=True):
if cpu or mps: if cpu or mps:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
elif device: # non-cpu device requested elif device: # non-cpu device requested
if device == 'cuda':
device = '0'
visible = os.environ.get('CUDA_VISIBLE_DEVICES', None) visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available() os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))): if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):

Loading…
Cancel
Save