mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
torch.fx.wrap not working with older pytorch, trying register_notrace instead
This commit is contained in:
parent
842a786626
commit
b7ced7c40c
@ -8,7 +8,6 @@ from .config import use_fused_attn
|
|||||||
from .pos_embed_sincos import apply_rot_embed_cat
|
from .pos_embed_sincos import apply_rot_embed_cat
|
||||||
|
|
||||||
|
|
||||||
@torch.fx.wrap
|
|
||||||
def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
||||||
return scores if attn_mask is None else scores + attn_mask
|
return scores if attn_mask is None else scores + attn_mask
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ except ImportError:
|
|||||||
|
|
||||||
# Layers we went to treat as leaf modules
|
# Layers we went to treat as leaf modules
|
||||||
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
|
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
|
||||||
from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
|
from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc, maybe_add_mask
|
||||||
from timm.layers.non_local_attn import BilinearAttnTransform
|
from timm.layers.non_local_attn import BilinearAttnTransform
|
||||||
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
|
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
|
||||||
from timm.layers.norm_act import (
|
from timm.layers.norm_act import (
|
||||||
@ -79,6 +79,7 @@ def get_notrace_modules():
|
|||||||
_autowrap_functions = {
|
_autowrap_functions = {
|
||||||
resample_abs_pos_embed,
|
resample_abs_pos_embed,
|
||||||
resample_abs_pos_embed_nhwc,
|
resample_abs_pos_embed_nhwc,
|
||||||
|
maybe_add_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user