Update torch_utils.py (#1895)
parent
0f11aaf551
commit
ffef77124e
|
@ -61,7 +61,7 @@ def select_device(device='', batch_size=None):
|
|||
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
|
||||
assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
|
||||
|
||||
cuda = torch.cuda.is_available() and not cpu
|
||||
cuda = not cpu and torch.cuda.is_available()
|
||||
if cuda:
|
||||
n = torch.cuda.device_count()
|
||||
if n > 1 and batch_size: # check that batch_size is compatible with device_count
|
||||
|
|
Loading…
Reference in New Issue