Adding pos embed resize fns to FX autowrap exceptions

This commit is contained in:
Ross Wightman 2024-06-10 12:06:47 -07:00
parent f0fb471b26
commit 5e9ff5798f
2 changed files with 7 additions and 4 deletions

View File

@ -15,7 +15,7 @@ _logger = logging.getLogger(__name__)
def resample_abs_pos_embed(
posemb,
posemb: torch.Tensor,
new_size: List[int],
old_size: Optional[List[int]] = None,
num_prefix_tokens: int = 1,
@ -58,7 +58,7 @@ def resample_abs_pos_embed(
def resample_abs_pos_embed_nhwc(
posemb,
posemb: torch.Tensor,
new_size: List[int],
interpolation: str = 'bicubic',
antialias: bool = True,
@ -69,7 +69,6 @@ def resample_abs_pos_embed_nhwc(
orig_dtype = posemb.dtype
posemb = posemb.float()
# do the interpolation
posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)

View File

@ -18,6 +18,7 @@ except ImportError:
# Layers we went to treat as leaf modules
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
from timm.layers.norm_act import (
@ -75,7 +76,10 @@ def get_notrace_modules():
# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()
_autowrap_functions = {
resample_abs_pos_embed,
resample_abs_pos_embed_nhwc,
}
def register_notrace_function(func: Callable):