Merge pull request #1890 from Separius/patch-1
use float in resample_abs_pos_embed_nhwcpull/1903/head
commit
3b8ef3f32f
|
@ -64,12 +64,14 @@ def resample_abs_pos_embed_nhwc(
|
|||
if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
|
||||
return posemb
|
||||
|
||||
previous_dtype = posemb.dtype
|
||||
posemb = posemb.float()
|
||||
# do the interpolation
|
||||
posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
|
||||
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
|
||||
posemb = posemb.permute(0, 2, 3, 1)
|
||||
posemb = posemb.permute(0, 2, 3, 1).to(previous_dtype)
|
||||
|
||||
if not torch.jit.is_scripting() and verbose:
|
||||
_logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
|
||||
|
||||
return posemb
|
||||
return posemb
|
||||
|
|
Loading…
Reference in New Issue