diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py
index b9a23921..473b095a 100644
--- a/timm/layers/patch_embed.py
+++ b/timm/layers/patch_embed.py
@@ -197,7 +197,11 @@ def resample_patch_embed(
         return resampled_kernel.reshape(new_size)

     v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
-    return v_resample_kernel(patch_embed)
+    orig_dtype = patch_embed.dtype
+    patch_embed = patch_embed.float()
+    patch_embed = v_resample_kernel(patch_embed)
+    patch_embed = patch_embed.to(orig_dtype)
+    return patch_embed


 # def divs(n, m=None):
diff --git a/timm/layers/pos_embed.py b/timm/layers/pos_embed.py
index 426c1c13..6be0017f 100644
--- a/timm/layers/pos_embed.py
+++ b/timm/layers/pos_embed.py
@@ -40,9 +40,12 @@ def resample_abs_pos_embed(

     # do the interpolation
     embed_dim = posemb.shape[-1]
+    orig_dtype = posemb.dtype
+    posemb = posemb.float()  # interpolate needs float32
     posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
     posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
     posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
+    posemb = posemb.to(orig_dtype)

     # add back extra (class, etc) prefix tokens
     if posemb_prefix is not None:
@@ -64,12 +67,12 @@ def resample_abs_pos_embed_nhwc(
     if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
         return posemb

-    previous_dtype = posemb.dtype
+    orig_dtype = posemb.dtype
     posemb = posemb.float()
     # do the interpolation
     posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
     posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
-    posemb = posemb.permute(0, 2, 3, 1).to(previous_dtype)
+    posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)

     if not torch.jit.is_scripting() and verbose:
         _logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
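
For context, below is a minimal standalone sketch (not part of the diff) of the dtype round-trip pattern this change applies in all three functions: upcast the embedding tensor to float32 before resampling, since `F.interpolate` (particularly with `antialias=True`) may not support reduced-precision dtypes such as bfloat16 on all backends, then cast the result back to the caller's dtype. The helper name `resize_in_fp32` and the tensor shapes are illustrative, not from timm.

```python
import torch
import torch.nn.functional as F


def resize_in_fp32(x: torch.Tensor, new_hw) -> torch.Tensor:
    # x is assumed NCHW; `resize_in_fp32` is a hypothetical helper,
    # mirroring the upcast/downcast pattern from the patch above.
    orig_dtype = x.dtype
    x = x.float()  # run the interpolation in float32
    x = F.interpolate(x, size=new_hw, mode='bicubic', antialias=True)
    return x.to(orig_dtype)  # restore the caller's dtype


# Usage: a bfloat16 grid of position embeddings resized from 14x14 to 16x16.
posemb = torch.randn(1, 768, 14, 14, dtype=torch.bfloat16)
resized = resize_in_fp32(posemb, (16, 16))
assert resized.dtype == torch.bfloat16
assert resized.shape[-2:] == (16, 16)
```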