Update swin_transformer.py

make `SwimTransformer`'s `patch_embed` customizable through the constructor
This commit is contained in:
Laureηt 2023-10-30 16:15:05 +00:00 committed by Ross Wightman
parent 68a121402f
commit d023154bb5

View File

@ -469,6 +469,7 @@ class SwinTransformer(nn.Module):
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.1,
embed_layer: Callable = PatchEmbed,
norm_layer: Union[str, Callable] = nn.LayerNorm,
weight_init: str = '',
**kwargs,
@ -489,6 +490,7 @@ class SwinTransformer(nn.Module):
drop_rate: Dropout rate.
attn_drop_rate (float): Attention dropout rate.
drop_path_rate (float): Stochastic depth rate.
embed_layer: Patch embedding layer.
norm_layer (nn.Module): Normalization layer.
"""
super().__init__()
@ -506,7 +508,7 @@ class SwinTransformer(nn.Module):
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,