Update swin_transformer.py

Make `SwinTransformer`'s `patch_embed` customizable through the constructor.
Author: Laureηt (2023-10-30 16:15:05 +00:00), committed by Ross Wightman
Parent: 68a121402f
Commit: d023154bb5


```diff
@@ -469,6 +469,7 @@ class SwinTransformer(nn.Module):
             proj_drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.1,
+            embed_layer: Callable = PatchEmbed,
             norm_layer: Union[str, Callable] = nn.LayerNorm,
             weight_init: str = '',
             **kwargs,
@@ -489,6 +490,7 @@ class SwinTransformer(nn.Module):
             drop_rate: Dropout rate.
             attn_drop_rate (float): Attention dropout rate.
             drop_path_rate (float): Stochastic depth rate.
+            embed_layer: Patch embedding layer.
             norm_layer (nn.Module): Normalization layer.
         """
         super().__init__()
@@ -506,7 +508,7 @@ class SwinTransformer(nn.Module):
         embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
         # split image into non-overlapping patches
-        self.patch_embed = PatchEmbed(
+        self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
```
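
With this change, a custom embedding module can be injected without subclassing `SwinTransformer`. Below is a minimal sketch of the new argument in use; `CustomPatchEmbed` and its extra activation are hypothetical illustrations, not part of this commit, and the sketch assumes `PatchEmbed` is importable from `timm.layers` as in current timm:

```python
# Hypothetical sketch: inject a custom patch embedding via the new
# embed_layer argument. CustomPatchEmbed is illustrative only.
import torch
import torch.nn as nn
from timm.layers import PatchEmbed
from timm.models.swin_transformer import SwinTransformer


class CustomPatchEmbed(PatchEmbed):
    """Standard PatchEmbed with an extra activation after the projection."""

    def __init__(self, **kwargs):
        # Forwards img_size, patch_size, in_chans, embed_dim, etc. to PatchEmbed.
        super().__init__(**kwargs)
        self.act = nn.GELU()

    def forward(self, x):
        return self.act(super().forward(x))


# SwinTransformer calls embed_layer(img_size=..., patch_size=..., in_chans=...,
# embed_dim=..., ...), so any callable accepting those keyword args works here.
model = SwinTransformer(
    img_size=224,
    patch_size=4,
    in_chans=3,
    embed_dim=96,
    embed_layer=CustomPatchEmbed,  # previously hard-coded to PatchEmbed
)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000]) with the default num_classes
```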