use float in resample_abs_pos_embed_nhwc

since F.interpolate doesn't always support BFloat16
pull/1890/head
Sepehr Sameni 2023-07-28 16:01:42 -07:00 committed by GitHub
parent 8cb0ddac45
commit 40a518c194
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 2 deletions

View File

@ -64,12 +64,14 @@ def resample_abs_pos_embed_nhwc(
if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
return posemb
previous_dtype = posemb.dtype
posemb = posemb.float()
# do the interpolation
posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
posemb = posemb.permute(0, 2, 3, 1)
posemb = posemb.permute(0, 2, 3, 1).to(previous_dtype)
if not torch.jit.is_scripting() and verbose:
_logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
return posemb
return posemb