diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 34452c7c..dc152e43 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -469,6 +469,7 @@ class SwinTransformer(nn.Module): proj_drop_rate: float = 0., attn_drop_rate: float = 0., drop_path_rate: float = 0.1, + embed_layer: Callable = PatchEmbed, norm_layer: Union[str, Callable] = nn.LayerNorm, weight_init: str = '', **kwargs, @@ -489,6 +490,7 @@ class SwinTransformer(nn.Module): drop_rate: Dropout rate. attn_drop_rate (float): Attention dropout rate. drop_path_rate (float): Stochastic depth rate. + embed_layer: Patch embedding layer. norm_layer (nn.Module): Normalization layer. """ super().__init__() @@ -506,7 +508,7 @@ class SwinTransformer(nn.Module): embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] # split image into non-overlapping patches - self.patch_embed = PatchEmbed( + self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans,