[Fix] Fix type in Swin Transformer (#1274)
parent
574adbe43b
commit
4de0c3708d
|
@ -479,7 +479,7 @@ class SwinTransformer(BaseModule):
|
||||||
embed_dims (int): The feature dimension. Default: 96.
|
embed_dims (int): The feature dimension. Default: 96.
|
||||||
patch_size (int | tuple[int]): Patch size. Default: 4.
|
patch_size (int | tuple[int]): Patch size. Default: 4.
|
||||||
window_size (int): Window size. Default: 7.
|
window_size (int): Window size. Default: 7.
|
||||||
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
|
mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim.
|
||||||
Default: 4.
|
Default: 4.
|
||||||
depths (tuple[int]): Depths of each Swin Transformer stage.
|
depths (tuple[int]): Depths of each Swin Transformer stage.
|
||||||
Default: (2, 2, 6, 2).
|
Default: (2, 2, 6, 2).
|
||||||
|
@ -610,7 +610,7 @@ class SwinTransformer(BaseModule):
|
||||||
stage = SwinBlockSequence(
|
stage = SwinBlockSequence(
|
||||||
embed_dims=in_channels,
|
embed_dims=in_channels,
|
||||||
num_heads=num_heads[i],
|
num_heads=num_heads[i],
|
||||||
feedforward_channels=mlp_ratio * in_channels,
|
feedforward_channels=int(mlp_ratio * in_channels),
|
||||||
depth=depths[i],
|
depth=depths[i],
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
|
|
Loading…
Reference in New Issue