mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Adding pos embed resize fns to FX autowrap exceptions
This commit is contained in:
parent
f0fb471b26
commit
5e9ff5798f
@ -15,7 +15,7 @@ _logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def resample_abs_pos_embed(
|
||||
posemb,
|
||||
posemb: torch.Tensor,
|
||||
new_size: List[int],
|
||||
old_size: Optional[List[int]] = None,
|
||||
num_prefix_tokens: int = 1,
|
||||
@ -58,7 +58,7 @@ def resample_abs_pos_embed(
|
||||
|
||||
|
||||
def resample_abs_pos_embed_nhwc(
|
||||
posemb,
|
||||
posemb: torch.Tensor,
|
||||
new_size: List[int],
|
||||
interpolation: str = 'bicubic',
|
||||
antialias: bool = True,
|
||||
@ -69,7 +69,6 @@ def resample_abs_pos_embed_nhwc(
|
||||
|
||||
orig_dtype = posemb.dtype
|
||||
posemb = posemb.float()
|
||||
# do the interpolation
|
||||
posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
|
||||
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
|
||||
posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)
|
||||
|
@ -18,6 +18,7 @@ except ImportError:
|
||||
|
||||
# Layers we went to treat as leaf modules
|
||||
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
|
||||
from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
|
||||
from timm.layers.non_local_attn import BilinearAttnTransform
|
||||
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
|
||||
from timm.layers.norm_act import (
|
||||
@ -75,7 +76,10 @@ def get_notrace_modules():
|
||||
|
||||
|
||||
# Functions we want to autowrap (treat them as leaves)
|
||||
_autowrap_functions = set()
|
||||
_autowrap_functions = {
|
||||
resample_abs_pos_embed,
|
||||
resample_abs_pos_embed_nhwc,
|
||||
}
|
||||
|
||||
|
||||
def register_notrace_function(func: Callable):
|
||||
|
Loading…
x
Reference in New Issue
Block a user