mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Always perform patch-embedding and position-embedding resampling in float32 (cast the input to float, then cast the result back to its original dtype). Fixes #1811
This commit is contained in:
parent
150356c493
commit
8e4480e4b6
@ -197,7 +197,11 @@ def resample_patch_embed(
|
||||
return resampled_kernel.reshape(new_size)
|
||||
|
||||
v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
|
||||
return v_resample_kernel(patch_embed)
|
||||
orig_dtype = patch_embed.dtype
|
||||
patch_embed = patch_embed.float()
|
||||
patch_embed = v_resample_kernel(patch_embed)
|
||||
patch_embed = patch_embed.to(orig_dtype)
|
||||
return patch_embed
|
||||
|
||||
|
||||
# def divs(n, m=None):
|
||||
|
@ -40,9 +40,12 @@ def resample_abs_pos_embed(
|
||||
|
||||
# do the interpolation
|
||||
embed_dim = posemb.shape[-1]
|
||||
orig_dtype = posemb.dtype
|
||||
posemb = posemb.float() # interpolate needs float32
|
||||
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
|
||||
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
|
||||
posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
|
||||
posemb = posemb.to(orig_dtype)
|
||||
|
||||
# add back extra (class, etc) prefix tokens
|
||||
if posemb_prefix is not None:
|
||||
@ -64,12 +67,12 @@ def resample_abs_pos_embed_nhwc(
|
||||
if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
|
||||
return posemb
|
||||
|
||||
previous_dtype = posemb.dtype
|
||||
orig_dtype = posemb.dtype
|
||||
posemb = posemb.float()
|
||||
# do the interpolation
|
||||
posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
|
||||
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
|
||||
posemb = posemb.permute(0, 2, 3, 1).to(previous_dtype)
|
||||
posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)
|
||||
|
||||
if not torch.jit.is_scripting() and verbose:
|
||||
_logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
|
||||
|
Loading…
x
Reference in New Issue
Block a user