Switch RandoErasing back to on GPU normal sampling

pull/62/head
Ross Wightman 2019-12-05 22:35:08 -08:00
parent 3129bdb2c1
commit 0161de0127
1 changed files with 2 additions and 4 deletions

View File

@ -7,12 +7,10 @@ def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='
# NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
# paths, flip the order so normal is run on CPU if this becomes a problem
# Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
# will revert back to doing normal_() on GPU when it's in next release
if per_pixel:
return torch.empty(
patch_size, dtype=dtype).normal_().to(device=device)
return torch.empty(patch_size, dtype=dtype, device=device).normal_()
elif rand_color:
return torch.empty((patch_size[0], 1, 1), dtype=dtype).normal_().to(device=device)
return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
else:
return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)