mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Adding pos embed resize fns to FX autowrap exceptions
This commit is contained in:
parent
f0fb471b26
commit
5e9ff5798f
@ -15,7 +15,7 @@ _logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def resample_abs_pos_embed(
|
def resample_abs_pos_embed(
|
||||||
posemb,
|
posemb: torch.Tensor,
|
||||||
new_size: List[int],
|
new_size: List[int],
|
||||||
old_size: Optional[List[int]] = None,
|
old_size: Optional[List[int]] = None,
|
||||||
num_prefix_tokens: int = 1,
|
num_prefix_tokens: int = 1,
|
||||||
@ -58,7 +58,7 @@ def resample_abs_pos_embed(
|
|||||||
|
|
||||||
|
|
||||||
def resample_abs_pos_embed_nhwc(
|
def resample_abs_pos_embed_nhwc(
|
||||||
posemb,
|
posemb: torch.Tensor,
|
||||||
new_size: List[int],
|
new_size: List[int],
|
||||||
interpolation: str = 'bicubic',
|
interpolation: str = 'bicubic',
|
||||||
antialias: bool = True,
|
antialias: bool = True,
|
||||||
@ -69,7 +69,6 @@ def resample_abs_pos_embed_nhwc(
|
|||||||
|
|
||||||
orig_dtype = posemb.dtype
|
orig_dtype = posemb.dtype
|
||||||
posemb = posemb.float()
|
posemb = posemb.float()
|
||||||
# do the interpolation
|
|
||||||
posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
|
posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
|
||||||
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
|
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
|
||||||
posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)
|
posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)
|
||||||
|
@ -18,6 +18,7 @@ except ImportError:
|
|||||||
|
|
||||||
# Layers we went to treat as leaf modules
|
# Layers we went to treat as leaf modules
|
||||||
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
|
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
|
||||||
|
from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
|
||||||
from timm.layers.non_local_attn import BilinearAttnTransform
|
from timm.layers.non_local_attn import BilinearAttnTransform
|
||||||
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
|
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
|
||||||
from timm.layers.norm_act import (
|
from timm.layers.norm_act import (
|
||||||
@ -75,7 +76,10 @@ def get_notrace_modules():
|
|||||||
|
|
||||||
|
|
||||||
# Functions we want to autowrap (treat them as leaves)
|
# Functions we want to autowrap (treat them as leaves)
|
||||||
_autowrap_functions = set()
|
_autowrap_functions = {
|
||||||
|
resample_abs_pos_embed,
|
||||||
|
resample_abs_pos_embed_nhwc,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def register_notrace_function(func: Callable):
|
def register_notrace_function(func: Callable):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user