torch.fx.wrap not working with older pytorch, trying register_notrace instead

Ross Wightman 2025-05-25 14:13:36 -07:00
parent 842a786626
commit b7ced7c40c
2 changed files with 2 additions and 2 deletions
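
For context: torch.fx.wrap registers a function at import time so symbolic tracing leaves it as a single call_function node instead of tracing into it, while the registry route hands the function to the tracer only when a graph is actually built. The sketch below uses only the public torch.fx API and a made-up toy module (ScoresOnly) to show both mechanisms; it is illustrative, not the timm code.

    from typing import Optional

    import torch
    import torch.fx


    def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        return scores if attn_mask is None else scores + attn_mask


    class ScoresOnly(torch.nn.Module):
        # toy stand-in for an attention block that may or may not receive a mask
        def forward(self, scores, attn_mask=None):
            return maybe_add_mask(scores, attn_mask)


    # Import-time registration: same effect as the @torch.fx.wrap decorator being
    # removed in this commit (per the commit message, the decorator form is what
    # breaks on older PyTorch).
    torch.fx.wrap('maybe_add_mask')

    # Trace-time registration: pass the function to the tracer as an autowrap
    # function, which is the mechanism timm's _autowrap_functions set feeds.
    tracer = torch.fx.Tracer(autowrap_functions=(maybe_add_mask,))
    graph = tracer.trace(ScoresOnly())
    print(graph)  # maybe_add_mask shows up as one opaque call_function node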


@@ -8,7 +8,6 @@ from .config import use_fused_attn
 from .pos_embed_sincos import apply_rot_embed_cat
-@torch.fx.wrap
 def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
     return scores if attn_mask is None else scores + attn_mask
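
The function body is unchanged; only the decorator goes away. It still needs to stay opaque to FX tracing so the Optional attn_mask branch is decided at run time rather than frozen at trace time. Below is a hedged sketch of how a helper like this is typically called inside scaled dot-product attention; the surrounding code is illustrative, not the timm call site.

    from typing import Optional

    import torch


    def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        # same function as in the hunk above, repeated so this sketch is self-contained
        return scores if attn_mask is None else scores + attn_mask


    def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # illustrative scaled dot-product attention, not the timm implementation
        scores = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5
        scores = maybe_add_mask(scores, attn_mask)  # stays a leaf call under FX tracing
        return scores.softmax(dim=-1) @ v

The hunks below come from the second changed file, timm's FX feature-extraction registry, where the function is now imported and added to the autowrap set.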


@@ -18,7 +18,7 @@ except ImportError:
 # Layers we went to treat as leaf modules
 from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
-from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
+from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc, maybe_add_mask
 from timm.layers.non_local_attn import BilinearAttnTransform
 from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
 from timm.layers.norm_act import (
@@ -79,6 +79,7 @@ def get_notrace_modules():
 _autowrap_functions = {
     resample_abs_pos_embed,
     resample_abs_pos_embed_nhwc,
+    maybe_add_mask,
 }
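
The _autowrap_functions set is what ultimately gets handed to the FX tracer when a graph-based feature extractor is built, so every function in it, now including maybe_add_mask, is kept as an opaque node. A minimal sketch of that flow follows, assuming torchvision's create_feature_extractor (which accepts tracer_kwargs); the private module path (timm.models._features_fx), the _leaf_modules name, the model name, and the return_nodes value are assumptions made for illustration and are not part of this commit.

    import torch
    import timm
    from torchvision.models.feature_extraction import create_feature_extractor

    # Private registry internals; timm normally passes these through its own
    # feature-extraction wrappers rather than having users import them directly.
    from timm.models._features_fx import _autowrap_functions, _leaf_modules

    model = timm.create_model('vit_base_patch16_224', pretrained=False)  # illustrative model

    extractor = create_feature_extractor(
        model,
        return_nodes={'blocks.11': 'feat'},  # illustrative node name
        tracer_kwargs={
            'leaf_modules': list(_leaf_modules),
            'autowrap_functions': list(_autowrap_functions),
        },
    )

    out = extractor(torch.randn(1, 3, 224, 224))
    print(out['feat'].shape)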