diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py
index b152b544..e8a5a9cf 100644
--- a/timm/models/swin_transformer_v2.py
+++ b/timm/models/swin_transformer_v2.py
@@ -22,7 +22,7 @@ import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
-    resample_patch_embed, ndgrid
+    resample_patch_embed, ndgrid, get_act_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -206,7 +206,7 @@ class SwinTransformerV2Block(nn.Module):
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,
-            act_layer: nn.Module = nn.GELU,
+            act_layer: LayerType = "gelu",
             norm_layer: nn.Module = nn.LayerNorm,
             pretrained_window_size: _int_or_tuple_2_t = 0,
     ) -> None:
@@ -235,6 +235,7 @@ class SwinTransformerV2Block(nn.Module):
         self.shift_size: Tuple[int, int] = ss
         self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio
+        act_layer = get_act_layer(act_layer)
 
         self.attn = WindowAttention(
             dim,
@@ -372,6 +373,7 @@ class SwinTransformerV2Stage(nn.Module):
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,
+            act_layer: Union[str, Callable] = 'gelu',
             norm_layer: nn.Module = nn.LayerNorm,
             pretrained_window_size: _int_or_tuple_2_t = 0,
             output_nchw: bool = False,
@@ -390,6 +392,7 @@ class SwinTransformerV2Stage(nn.Module):
             proj_drop: Projection dropout rate
             attn_drop: Attention dropout rate.
             drop_path: Stochastic depth rate.
+            act_layer: Activation layer type.
             norm_layer: Normalization layer.
             pretrained_window_size: Local window size in pretraining.
             output_nchw: Output tensors on NCHW format instead of NHWC.
@@ -424,6 +427,7 @@ class SwinTransformerV2Stage(nn.Module):
                 proj_drop=proj_drop,
                 attn_drop=attn_drop,
                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                act_layer=act_layer,
                 norm_layer=norm_layer,
                 pretrained_window_size=pretrained_window_size,
             )
@@ -471,6 +475,7 @@ class SwinTransformerV2(nn.Module):
             proj_drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.1,
+            act_layer: Union[str, Callable] = 'gelu',
             norm_layer: Callable = nn.LayerNorm,
             pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0),
             **kwargs,
@@ -492,6 +497,7 @@ class SwinTransformerV2(nn.Module):
             attn_drop_rate: Attention dropout rate.
             drop_path_rate: Stochastic depth rate.
             norm_layer: Normalization layer.
+            act_layer: Activation layer type.
             patch_norm: If True, add normalization after patch embedding.
             pretrained_window_sizes: Pretrained window sizes of each layer.
             output_fmt: Output tensor format if not None, otherwise output 'NHWC' by default.
@@ -541,6 +547,7 @@ class SwinTransformerV2(nn.Module):
                 proj_drop=proj_drop_rate,
                 attn_drop=attn_drop_rate,
                 drop_path=dpr[i],
+                act_layer=act_layer,
                 norm_layer=norm_layer,
                 pretrained_window_size=pretrained_window_sizes[i],
             )]
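
For reference, a minimal usage sketch (not part of this diff) exercising the new `act_layer` argument end to end. It assumes the registered swinv2 entrypoints forward extra keyword arguments to the `SwinTransformerV2` constructor, where the string alias is resolved per block via `timm.layers.get_act_layer`:

```python
import torch
import timm

# Minimal sketch, assuming the swinv2 entrypoints pass extra kwargs
# (act_layer here) through to SwinTransformerV2. The string alias is
# resolved by get_act_layer(), so an nn.Module class such as nn.SiLU
# would also be accepted; the default remains 'gelu'.
model = timm.create_model(
    'swinv2_tiny_window8_256',
    pretrained=False,
    act_layer='silu',
    num_classes=10,
)

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 10])
```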