mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
use float in resample_abs_pos_embed_nhwc
since F.interpolate doesn't always support BFloat16
This commit is contained in:
parent
8cb0ddac45
commit
40a518c194
@ -64,12 +64,14 @@ def resample_abs_pos_embed_nhwc(
|
||||
if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
|
||||
return posemb
|
||||
|
||||
previous_dtype = posemb.dtype
|
||||
posemb = posemb.float()
|
||||
# do the interpolation
|
||||
posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
|
||||
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
|
||||
posemb = posemb.permute(0, 2, 3, 1)
|
||||
posemb = posemb.permute(0, 2, 3, 1).to(previous_dtype)
|
||||
|
||||
if not torch.jit.is_scripting() and verbose:
|
||||
_logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
|
||||
|
||||
return posemb
|
||||
return posemb
|
||||
|
Loading…
x
Reference in New Issue
Block a user