diff --git a/timm/layers/attention.py b/timm/layers/attention.py
index 936ce6c0..73cf93f8 100644
--- a/timm/layers/attention.py
+++ b/timm/layers/attention.py
@@ -8,7 +8,6 @@ from .config import use_fused_attn
 from .pos_embed_sincos import apply_rot_embed_cat
 
 
-@torch.fx.wrap
 def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
     return scores if attn_mask is None else scores + attn_mask
 
diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py
index 6679b38b..0352d79a 100644
--- a/timm/models/_features_fx.py
+++ b/timm/models/_features_fx.py
@@ -18,7 +18,7 @@ except ImportError:
 
 # Layers we went to treat as leaf modules
 from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
-from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
+from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc, maybe_add_mask
 from timm.layers.non_local_attn import BilinearAttnTransform
 from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
 from timm.layers.norm_act import (
@@ -79,6 +79,7 @@ def get_notrace_modules():
 _autowrap_functions = {
     resample_abs_pos_embed,
     resample_abs_pos_embed_nhwc,
+    maybe_add_mask,
 }
 
 
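Not part of the patch: a minimal, self-contained sketch of why the autowrap registration matters once the @torch.fx.wrap decorator is removed. It drives torch.fx.Tracer's autowrap_functions argument directly; ToyScores and the local copy of maybe_add_mask are illustrative stand-ins only.

from typing import Optional

import torch
import torch.fx


# Copied from timm/layers/attention.py so the sketch runs without timm installed.
def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
    return scores if attn_mask is None else scores + attn_mask


class ToyScores(torch.nn.Module):
    # Hypothetical stand-in for an attention block applying an optional mask.
    def forward(self, scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        return maybe_add_mask(scores, attn_mask)


# Without autowrapping, symbolic tracing descends into maybe_add_mask and
# evaluates `attn_mask is None` against a Proxy at trace time, baking the
# masked branch into the graph. Listing the function in autowrap_functions
# records it as a single call_function node instead, so the None-check still
# runs at call time.
graph = torch.fx.Tracer(autowrap_functions=(maybe_add_mask,)).trace(ToyScores())
print(graph)

The patch gets the same leaf behavior by adding the function to _autowrap_functions, which the FX feature-extraction helpers in _features_fx.py hand to their tracer, in place of the module-level @torch.fx.wrap.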