mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fast_vit: propagate act_layer argument
This commit is contained in:
parent
95ba90157f
commit
8ba2038e6b
@ -421,7 +421,8 @@ class ReparamLargeKernelConv(nn.Module):
|
|||||||
def convolutional_stem(
|
def convolutional_stem(
|
||||||
in_chs: int,
|
in_chs: int,
|
||||||
out_chs: int,
|
out_chs: int,
|
||||||
inference_mode: bool = False
|
inference_mode: bool = False,
|
||||||
|
act_layer: nn.Module = nn.GELU,
|
||||||
) -> nn.Sequential:
|
) -> nn.Sequential:
|
||||||
"""Build convolutional stem with MobileOne blocks.
|
"""Build convolutional stem with MobileOne blocks.
|
||||||
|
|
||||||
@ -439,6 +440,7 @@ def convolutional_stem(
|
|||||||
out_chs=out_chs,
|
out_chs=out_chs,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=2,
|
stride=2,
|
||||||
|
act_layer=act_layer,
|
||||||
inference_mode=inference_mode,
|
inference_mode=inference_mode,
|
||||||
),
|
),
|
||||||
MobileOneBlock(
|
MobileOneBlock(
|
||||||
@ -447,13 +449,15 @@ def convolutional_stem(
|
|||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=2,
|
stride=2,
|
||||||
group_size=1,
|
group_size=1,
|
||||||
|
act_layer=act_layer,
|
||||||
inference_mode=inference_mode,
|
inference_mode=inference_mode,
|
||||||
),
|
),
|
||||||
MobileOneBlock(
|
MobileOneBlock(
|
||||||
in_chs=out_chs,
|
in_chs=out_chs,
|
||||||
out_chs=out_chs,
|
out_chs=out_chs,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
|
act_layer=act_layer,
|
||||||
inference_mode=inference_mode,
|
inference_mode=inference_mode,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -1121,6 +1125,7 @@ class FastVit(nn.Module):
|
|||||||
in_chans,
|
in_chans,
|
||||||
embed_dims[0],
|
embed_dims[0],
|
||||||
inference_mode,
|
inference_mode,
|
||||||
|
act_layer
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build the main stages of the network architecture
|
# Build the main stages of the network architecture
|
||||||
|
Loading…
x
Reference in New Issue
Block a user