fast_vit: propagate act_layer argument

This commit is contained in:
Yassine 2023-09-26 19:38:38 -07:00 committed by Ross Wightman
parent 95ba90157f
commit 8ba2038e6b

View File

@ -421,7 +421,8 @@ class ReparamLargeKernelConv(nn.Module):
def convolutional_stem( def convolutional_stem(
in_chs: int, in_chs: int,
out_chs: int, out_chs: int,
inference_mode: bool = False inference_mode: bool = False,
act_layer: nn.Module = nn.GELU,
) -> nn.Sequential: ) -> nn.Sequential:
"""Build convolutional stem with MobileOne blocks. """Build convolutional stem with MobileOne blocks.
@ -439,6 +440,7 @@ def convolutional_stem(
out_chs=out_chs, out_chs=out_chs,
kernel_size=3, kernel_size=3,
stride=2, stride=2,
act_layer=act_layer,
inference_mode=inference_mode, inference_mode=inference_mode,
), ),
MobileOneBlock( MobileOneBlock(
@ -447,13 +449,15 @@ def convolutional_stem(
kernel_size=3, kernel_size=3,
stride=2, stride=2,
group_size=1, group_size=1,
act_layer=act_layer,
inference_mode=inference_mode, inference_mode=inference_mode,
), ),
MobileOneBlock( MobileOneBlock(
in_chs=out_chs, in_chs=out_chs,
out_chs=out_chs, out_chs=out_chs,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
act_layer=act_layer,
inference_mode=inference_mode, inference_mode=inference_mode,
), ),
) )
@ -1121,6 +1125,7 @@ class FastVit(nn.Module):
in_chans, in_chans,
embed_dims[0], embed_dims[0],
inference_mode, inference_mode,
act_layer
) )
# Build the main stages of the network architecture # Build the main stages of the network architecture