mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
torch.fx.wrap not working with older pytorch, trying register_notrace instead
This commit is contained in:
parent
842a786626
commit
b7ced7c40c
@ -8,7 +8,6 @@ from .config import use_fused_attn
|
||||
from .pos_embed_sincos import apply_rot_embed_cat
|
||||
|
||||
|
||||
@torch.fx.wrap
|
||||
def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
||||
return scores if attn_mask is None else scores + attn_mask
|
||||
|
||||
|
@ -18,7 +18,7 @@ except ImportError:
|
||||
|
||||
# Layers we went to treat as leaf modules
|
||||
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
|
||||
from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
|
||||
from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc, maybe_add_mask
|
||||
from timm.layers.non_local_attn import BilinearAttnTransform
|
||||
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
|
||||
from timm.layers.norm_act import (
|
||||
@ -79,6 +79,7 @@ def get_notrace_modules():
|
||||
_autowrap_functions = {
|
||||
resample_abs_pos_embed,
|
||||
resample_abs_pos_embed_nhwc,
|
||||
maybe_add_mask,
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user