Register orm_act layers as leaf modules

This commit is contained in:
Daniel Suess 2023-08-10 15:37:26 +10:00
parent c692715388
commit 986de90360
No known key found for this signature in database

View File

@ -18,6 +18,15 @@ except ImportError:
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
from timm.layers.norm_act import (
BatchNormAct2d,
SyncBatchNormAct,
FrozenBatchNormAct2d,
GroupNormAct,
GroupNorm1Act,
LayerNormAct,
LayerNormAct2d
)
__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
@ -30,7 +39,14 @@ _leaf_modules = {
BilinearAttnTransform, # reason: flow control t <= 1
# Reason: get_same_padding has a max which raises a control flow error
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]),
BatchNormAct2d,
SyncBatchNormAct,
FrozenBatchNormAct2d,
GroupNormAct,
GroupNorm1Act,
LayerNormAct,
LayerNormAct2d,
}
try: