From 5e9ff5798f5c7bd463944c483fd9619e701dd349 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 10 Jun 2024 12:06:47 -0700 Subject: [PATCH] Adding pos embed resize fns to FX autowrap exceptions --- timm/layers/pos_embed.py | 5 ++--- timm/models/_features_fx.py | 6 +++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/timm/layers/pos_embed.py b/timm/layers/pos_embed.py index 3e67be00..0d50207d 100644 --- a/timm/layers/pos_embed.py +++ b/timm/layers/pos_embed.py @@ -15,7 +15,7 @@ _logger = logging.getLogger(__name__) def resample_abs_pos_embed( - posemb, + posemb: torch.Tensor, new_size: List[int], old_size: Optional[List[int]] = None, num_prefix_tokens: int = 1, @@ -58,7 +58,7 @@ def resample_abs_pos_embed( def resample_abs_pos_embed_nhwc( - posemb, + posemb: torch.Tensor, new_size: List[int], interpolation: str = 'bicubic', antialias: bool = True, @@ -69,7 +69,6 @@ def resample_abs_pos_embed_nhwc( orig_dtype = posemb.dtype posemb = posemb.float() - # do the interpolation posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2) posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias) posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype) diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index 3a276046..1ea4a4f4 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -18,6 +18,7 @@ except ImportError: # Layers we went to treat as leaf modules from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format +from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc from timm.layers.non_local_attn import BilinearAttnTransform from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame from timm.layers.norm_act import ( @@ -75,7 +76,10 @@ def get_notrace_modules(): # Functions we want to autowrap (treat them as leaves) -_autowrap_functions = set() +_autowrap_functions = { + resample_abs_pos_embed, + resample_abs_pos_embed_nhwc, +} def register_notrace_function(func: Callable):