Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Merge pull request #2105 from SmilingWolf/main
SwinV2: add configurable act_layer argument
Commit 2ec2f1aa73
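This change makes the activation function in SwinV2 configurable. On `SwinTransformerV2Block`, `act_layer` changes from a hard-wired class (`nn.Module = nn.GELU`) to a `LayerType` spec defaulting to `"gelu"`, resolved at construction time with `get_act_layer()`; a matching `act_layer: Union[str, Callable] = 'gelu'` argument is threaded through `SwinTransformerV2Stage` and `SwinTransformerV2` down to each block.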
timm/models/swin_transformer_v2.py

@@ -22,7 +22,7 @@ import torch.utils.checkpoint as checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
-    resample_patch_embed, ndgrid
+    resample_patch_embed, ndgrid, get_act_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
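The two new imports do the heavy lifting: `get_act_layer` maps a string name to a concrete activation class (non-string specs pass through unchanged), and `LayerType` is timm's type alias for such layer specs. A minimal sketch of the resolution behavior; the exact class returned for a given name is an assumption here:

    import torch.nn as nn
    from timm.layers import get_act_layer

    act_cls = get_act_layer('relu')       # name -> activation class (nn.ReLU)
    act = act_cls(inplace=True)           # instantiate like any activation module

    passthrough = get_act_layer(nn.SiLU)  # non-string specs are returned unchanged
    assert passthrough is nn.SiLU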
@@ -206,7 +206,7 @@ class SwinTransformerV2Block(nn.Module):
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,
-            act_layer: nn.Module = nn.GELU,
+            act_layer: LayerType = "gelu",
             norm_layer: nn.Module = nn.LayerNorm,
             pretrained_window_size: _int_or_tuple_2_t = 0,
     ) -> None:
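The annotation change is also a small correction: `act_layer: nn.Module = nn.GELU` typed a class as an instance, whereas `LayerType` covers both spellings. The alias itself is not part of this diff; as an assumption, it is roughly:

    from typing import Callable, Type, Union
    import torch.nn as nn

    # Assumption: approximate shape of timm.layers.LayerType, not its verbatim definition.
    LayerType = Union[str, Callable, Type[nn.Module]]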
@@ -235,6 +235,7 @@ class SwinTransformerV2Block(nn.Module):
         self.shift_size: Tuple[int, int] = ss
         self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio
+        act_layer = get_act_layer(act_layer)

         self.attn = WindowAttention(
             dim,
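Resolving the spec once in the constructor lets everything downstream (the attention module, the MLP) keep treating `act_layer` as a plain class. The same normalize-once pattern in isolation, using a hypothetical `TinyMlp` that is not part of this diff:

    import torch.nn as nn
    from timm.layers import LayerType, get_act_layer

    class TinyMlp(nn.Module):
        # Hypothetical module showing the pattern: accept str or class, resolve once.
        def __init__(self, dim: int, act_layer: LayerType = 'gelu'):
            super().__init__()
            act_layer = get_act_layer(act_layer)  # 'gelu' and nn.GELU both become a class
            self.fc1 = nn.Linear(dim, 4 * dim)
            self.act = act_layer()
            self.fc2 = nn.Linear(4 * dim, dim)

        def forward(self, x):
            return self.fc2(self.act(self.fc1(x)))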
@@ -372,6 +373,7 @@ class SwinTransformerV2Stage(nn.Module):
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,
+            act_layer: Union[str, Callable] = 'gelu',
             norm_layer: nn.Module = nn.LayerNorm,
             pretrained_window_size: _int_or_tuple_2_t = 0,
             output_nchw: bool = False,
@@ -390,6 +392,7 @@ class SwinTransformerV2Stage(nn.Module):
             proj_drop: Projection dropout rate
             attn_drop: Attention dropout rate.
             drop_path: Stochastic depth rate.
+            act_layer: Activation layer type.
             norm_layer: Normalization layer.
             pretrained_window_size: Local window size in pretraining.
             output_nchw: Output tensors on NCHW format instead of NHWC.
@@ -424,6 +427,7 @@ class SwinTransformerV2Stage(nn.Module):
                 proj_drop=proj_drop,
                 attn_drop=attn_drop,
                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                act_layer=act_layer,
                 norm_layer=norm_layer,
                 pretrained_window_size=pretrained_window_size,
             )
@@ -471,6 +475,7 @@ class SwinTransformerV2(nn.Module):
             proj_drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.1,
+            act_layer: Union[str, Callable] = 'gelu',
             norm_layer: Callable = nn.LayerNorm,
             pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0),
             **kwargs,
@@ -492,6 +497,7 @@ class SwinTransformerV2(nn.Module):
             attn_drop_rate: Attention dropout rate.
             drop_path_rate: Stochastic depth rate.
             norm_layer: Normalization layer.
+            act_layer: Activation layer type.
             patch_norm: If True, add normalization after patch embedding.
             pretrained_window_sizes: Pretrained window sizes of each layer.
             output_fmt: Output tensor format if not None, otherwise output 'NHWC' by default.
@@ -541,6 +547,7 @@ class SwinTransformerV2(nn.Module):
                 proj_drop=proj_drop_rate,
                 attn_drop=attn_drop_rate,
                 drop_path=dpr[i],
+                act_layer=act_layer,
                 norm_layer=norm_layer,
                 pretrained_window_size=pretrained_window_sizes[i],
             )]
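Net effect of the PR: the activation becomes selectable at model creation time. A usage sketch, assuming `timm.create_model` forwards extra kwargs to the `SwinTransformerV2` constructor as usual (the model name is just one registered SwinV2 config):

    import timm

    # Default remains 'gelu'; pass another name (or an nn.Module class) to swap
    # the activation used in every transformer block's MLP.
    model = timm.create_model('swinv2_tiny_window8_256', pretrained=False, act_layer='relu')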