Always perform patch embed and position embed resampling in float32 (cast to float before resampling, then back to the original dtype). Fixes #1811

This commit is contained in:
Ross Wightman 2023-08-03 11:32:17 -07:00
parent 150356c493
commit 8e4480e4b6
2 changed files with 10 additions and 3 deletions

View File

@@ -197,7 +197,11 @@ def resample_patch_embed(
return resampled_kernel.reshape(new_size)
v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
-return v_resample_kernel(patch_embed)
+orig_dtype = patch_embed.dtype
+patch_embed = patch_embed.float()
+patch_embed = v_resample_kernel(patch_embed)
+patch_embed = patch_embed.to(orig_dtype)
+return patch_embed
# def divs(n, m=None):

View File

@@ -40,9 +40,12 @@ def resample_abs_pos_embed(
# do the interpolation
embed_dim = posemb.shape[-1]
+orig_dtype = posemb.dtype
+posemb = posemb.float() # interpolate needs float32
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
+posemb = posemb.to(orig_dtype)
# add back extra (class, etc) prefix tokens
if posemb_prefix is not None:
@@ -64,12 +67,12 @@ def resample_abs_pos_embed_nhwc(
if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
return posemb
-previous_dtype = posemb.dtype
+orig_dtype = posemb.dtype
posemb = posemb.float()
# do the interpolation
posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
-posemb = posemb.permute(0, 2, 3, 1).to(previous_dtype)
+posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)
if not torch.jit.is_scripting() and verbose:
_logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')