Torch CUDA synchronize update (#1826)
* torch.cuda.synchronize() update * torch.cuda.synchronize() update * torch.cuda.synchronize() update * newlinepull/1836/head
parent
0b6266f5e0
commit
9f5a18bb80
utils
|
@ -36,42 +36,41 @@ def init_torch_seeds(seed=0):
|
|||
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
|
||||
torch.manual_seed(seed)
|
||||
if seed == 0: # slower, more reproducible
|
||||
cudnn.deterministic = True
|
||||
cudnn.benchmark = False
|
||||
cudnn.benchmark, cudnn.deterministic = False, True
|
||||
else: # faster, less reproducible
|
||||
cudnn.deterministic = False
|
||||
cudnn.benchmark = True
|
||||
cudnn.benchmark, cudnn.deterministic = True, False
|
||||
|
||||
|
||||
def select_device(device='', batch_size=None):
|
||||
# device = 'cpu' or '0' or '0,1,2,3'
|
||||
cpu_request = device.lower() == 'cpu'
|
||||
if device and not cpu_request: # if device requested other than 'cpu'
|
||||
s = f'Using torch {torch.__version__} ' # string
|
||||
cpu = device.lower() == 'cpu'
|
||||
if cpu:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
||||
elif device: # non-cpu device requested
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
|
||||
assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availablity
|
||||
assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
|
||||
|
||||
cuda = False if cpu_request else torch.cuda.is_available()
|
||||
cuda = torch.cuda.is_available() and not cpu
|
||||
if cuda:
|
||||
c = 1024 ** 2 # bytes to MB
|
||||
ng = torch.cuda.device_count()
|
||||
if ng > 1 and batch_size: # check that batch_size is compatible with device_count
|
||||
assert batch_size % ng == 0, f'batch-size {batch_size} not multiple of GPU count {ng}'
|
||||
x = [torch.cuda.get_device_properties(i) for i in range(ng)]
|
||||
s = f'Using torch {torch.__version__} '
|
||||
for i, d in enumerate((device or '0').split(',')):
|
||||
if i == 1:
|
||||
s = ' ' * len(s)
|
||||
logger.info(f"{s}CUDA:{d} ({x[i].name}, {x[i].total_memory / c}MB)")
|
||||
n = torch.cuda.device_count()
|
||||
if n > 1 and batch_size: # check that batch_size is compatible with device_count
|
||||
assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
|
||||
space = ' ' * len(s)
|
||||
for i, d in enumerate(device.split(',') if device else range(n)):
|
||||
p = torch.cuda.get_device_properties(i)
|
||||
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
|
||||
else:
|
||||
logger.info(f'Using torch {torch.__version__} CPU')
|
||||
s += 'CPU'
|
||||
|
||||
logger.info('') # skip a line
|
||||
logger.info(f'{s}\n') # skip a line
|
||||
return torch.device('cuda:0' if cuda else 'cpu')
|
||||
|
||||
|
||||
def time_synchronized():
|
||||
# pytorch-accurate time
|
||||
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
return time.time()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue