Merge pull request #2105 from SmilingWolf/main

SwinV2: add configurable act_layer argument
Ross Wightman · 2024-03-05 22:17:10 -08:00 (committed by GitHub) · commit 2ec2f1aa73


@@ -22,7 +22,7 @@ import torch.utils.checkpoint as checkpoint
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
-    resample_patch_embed, ndgrid
+    resample_patch_embed, ndgrid, get_act_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -206,7 +206,7 @@ class SwinTransformerV2Block(nn.Module):
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,
-            act_layer: nn.Module = nn.GELU,
+            act_layer: LayerType = "gelu",
             norm_layer: nn.Module = nn.LayerNorm,
             pretrained_window_size: _int_or_tuple_2_t = 0,
     ) -> None:
@@ -235,6 +235,7 @@ class SwinTransformerV2Block(nn.Module):
         self.shift_size: Tuple[int, int] = ss
         self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio
+        act_layer = get_act_layer(act_layer)

         self.attn = WindowAttention(
             dim,
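For reference (not part of the diff): the string default works because `get_act_layer` from `timm.layers` resolves an activation name to the corresponding module class, and passes an already-callable layer through unchanged. A minimal sketch of the behavior this hunk relies on:

```python
from torch import nn
from timm.layers import get_act_layer

# A string name resolves to an activation module class...
act_cls = get_act_layer('gelu')
act = act_cls()                    # instantiate like any other nn.Module
print(act_cls, isinstance(act, nn.Module))

# ...while a callable (e.g. nn.SiLU) is returned unchanged,
# so both act_layer='silu' and act_layer=nn.SiLU are accepted.
assert get_act_layer(nn.SiLU) is nn.SiLU
```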
@@ -372,6 +373,7 @@ class SwinTransformerV2Stage(nn.Module):
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,
+            act_layer: Union[str, Callable] = 'gelu',
             norm_layer: nn.Module = nn.LayerNorm,
             pretrained_window_size: _int_or_tuple_2_t = 0,
             output_nchw: bool = False,
@@ -390,6 +392,7 @@ class SwinTransformerV2Stage(nn.Module):
             proj_drop: Projection dropout rate
             attn_drop: Attention dropout rate.
             drop_path: Stochastic depth rate.
+            act_layer: Activation layer type.
             norm_layer: Normalization layer.
             pretrained_window_size: Local window size in pretraining.
             output_nchw: Output tensors on NCHW format instead of NHWC.
@@ -424,6 +427,7 @@ class SwinTransformerV2Stage(nn.Module):
                 proj_drop=proj_drop,
                 attn_drop=attn_drop,
                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                act_layer=act_layer,
                 norm_layer=norm_layer,
                 pretrained_window_size=pretrained_window_size,
             )
@@ -471,6 +475,7 @@ class SwinTransformerV2(nn.Module):
             proj_drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.1,
+            act_layer: Union[str, Callable] = 'gelu',
             norm_layer: Callable = nn.LayerNorm,
             pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0),
             **kwargs,
@@ -492,6 +497,7 @@ class SwinTransformerV2(nn.Module):
             attn_drop_rate: Attention dropout rate.
             drop_path_rate: Stochastic depth rate.
             norm_layer: Normalization layer.
+            act_layer: Activation layer type.
             patch_norm: If True, add normalization after patch embedding.
             pretrained_window_sizes: Pretrained window sizes of each layer.
             output_fmt: Output tensor format if not None, otherwise output 'NHWC' by default.
@@ -541,6 +547,7 @@ class SwinTransformerV2(nn.Module):
                 proj_drop=proj_drop_rate,
                 attn_drop=attn_drop_rate,
                 drop_path=dpr[i],
+                act_layer=act_layer,
                 norm_layer=norm_layer,
                 pretrained_window_size=pretrained_window_sizes[i],
             )]
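For context, a quick usage sketch (not part of this diff) of what the new argument enables. It assumes a timm build that includes this change; the constructor arguments below are just illustrative settings, not a named config from the registry:

```python
import torch
from timm.models.swin_transformer_v2 import SwinTransformerV2

# Build a SwinV2 with a non-default activation; the string is resolved to a
# module class via get_act_layer inside each block (previously fixed to nn.GELU).
model = SwinTransformerV2(
    img_size=256,
    window_size=8,
    embed_dim=96,
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),
    act_layer='relu',
)
out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # expected: torch.Size([1, 1000])

# Passing a callable also works, e.g. act_layer=torch.nn.SiLU.
```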