mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update swin_transformer.py
make `SwimTransformer`'s `patch_embed` customizable through the constructor
This commit is contained in:
parent
68a121402f
commit
d023154bb5
@ -469,6 +469,7 @@ class SwinTransformer(nn.Module):
|
||||
proj_drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.1,
|
||||
embed_layer: Callable = PatchEmbed,
|
||||
norm_layer: Union[str, Callable] = nn.LayerNorm,
|
||||
weight_init: str = '',
|
||||
**kwargs,
|
||||
@ -489,6 +490,7 @@ class SwinTransformer(nn.Module):
|
||||
drop_rate: Dropout rate.
|
||||
attn_drop_rate (float): Attention dropout rate.
|
||||
drop_path_rate (float): Stochastic depth rate.
|
||||
embed_layer: Patch embedding layer.
|
||||
norm_layer (nn.Module): Normalization layer.
|
||||
"""
|
||||
super().__init__()
|
||||
@ -506,7 +508,7 @@ class SwinTransformer(nn.Module):
|
||||
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||
|
||||
# split image into non-overlapping patches
|
||||
self.patch_embed = PatchEmbed(
|
||||
self.patch_embed = embed_layer(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
|
Loading…
x
Reference in New Issue
Block a user