mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update swin_transformer.py
make `SwimTransformer`'s `patch_embed` customizable through the constructor
This commit is contained in:
parent
68a121402f
commit
d023154bb5
@ -469,6 +469,7 @@ class SwinTransformer(nn.Module):
|
|||||||
proj_drop_rate: float = 0.,
|
proj_drop_rate: float = 0.,
|
||||||
attn_drop_rate: float = 0.,
|
attn_drop_rate: float = 0.,
|
||||||
drop_path_rate: float = 0.1,
|
drop_path_rate: float = 0.1,
|
||||||
|
embed_layer: Callable = PatchEmbed,
|
||||||
norm_layer: Union[str, Callable] = nn.LayerNorm,
|
norm_layer: Union[str, Callable] = nn.LayerNorm,
|
||||||
weight_init: str = '',
|
weight_init: str = '',
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -489,6 +490,7 @@ class SwinTransformer(nn.Module):
|
|||||||
drop_rate: Dropout rate.
|
drop_rate: Dropout rate.
|
||||||
attn_drop_rate (float): Attention dropout rate.
|
attn_drop_rate (float): Attention dropout rate.
|
||||||
drop_path_rate (float): Stochastic depth rate.
|
drop_path_rate (float): Stochastic depth rate.
|
||||||
|
embed_layer: Patch embedding layer.
|
||||||
norm_layer (nn.Module): Normalization layer.
|
norm_layer (nn.Module): Normalization layer.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -506,7 +508,7 @@ class SwinTransformer(nn.Module):
|
|||||||
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||||
|
|
||||||
# split image into non-overlapping patches
|
# split image into non-overlapping patches
|
||||||
self.patch_embed = PatchEmbed(
|
self.patch_embed = embed_layer(
|
||||||
img_size=img_size,
|
img_size=img_size,
|
||||||
patch_size=patch_size,
|
patch_size=patch_size,
|
||||||
in_chans=in_chans,
|
in_chans=in_chans,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user