[Fix] Fix type in Swin Transformer (#1274)

commit b16310182b
parent a39f5856ce
Author: MengzhangLI (committed via GitHub)
Date: 2022-02-09 19:08:32 +08:00

@@ -479,7 +479,7 @@ class SwinTransformer(BaseModule):
         embed_dims (int): The feature dimension. Default: 96.
         patch_size (int | tuple[int]): Patch size. Default: 4.
         window_size (int): Window size. Default: 7.
-        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
+        mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim.
             Default: 4.
         depths (tuple[int]): Depths of each Swin Transformer stage.
             Default: (2, 2, 6, 2).
@@ -610,7 +610,7 @@ class SwinTransformer(BaseModule):
             stage = SwinBlockSequence(
                 embed_dims=in_channels,
                 num_heads=num_heads[i],
-                feedforward_channels=mlp_ratio * in_channels,
+                feedforward_channels=int(mlp_ratio * in_channels),
                 depth=depths[i],
                 window_size=window_size,
                 qkv_bias=qkv_bias,
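
For context, a minimal sketch of why the `int()` cast matters (the values below are illustrative, not from the commit): a non-integer `mlp_ratio` is exactly what the docstring change (`int` -> `int | float`) permits, but `feedforward_channels` is ultimately used as a layer width, and PyTorch's `nn.Linear` rejects float sizes.

```python
import torch.nn as nn

# Illustrative values, not taken from the commit: a non-integer
# mlp_ratio is the case the int | float docstring change allows.
embed_dims = 96
mlp_ratio = 4.5

# Before the fix: the product is a float (432.0), and building a
# linear layer with a float width raises a TypeError in torch.empty.
feedforward_channels = mlp_ratio * embed_dims
try:
    nn.Linear(embed_dims, feedforward_channels)
except TypeError as err:
    print(f"float width rejected: {err}")

# After the fix: int() keeps the hidden width integral, so the FFN
# builds cleanly for any int or float mlp_ratio.
feedforward_channels = int(mlp_ratio * embed_dims)
print(nn.Linear(embed_dims, feedforward_channels))
```

With the cast in place, the docstring and the constructor agree: `mlp_ratio` may be an `int` or a `float`, and the computed hidden dimension is always an `int`.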