diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py index f61b54e5..fbeeb39e 100644 --- a/timm/models/fastvit.py +++ b/timm/models/fastvit.py @@ -421,7 +421,8 @@ class ReparamLargeKernelConv(nn.Module): def convolutional_stem( in_chs: int, out_chs: int, - inference_mode: bool = False + inference_mode: bool = False, + act_layer: nn.Module = nn.GELU, ) -> nn.Sequential: """Build convolutional stem with MobileOne blocks. @@ -439,6 +440,7 @@ def convolutional_stem( out_chs=out_chs, kernel_size=3, stride=2, + act_layer=act_layer, inference_mode=inference_mode, ), MobileOneBlock( @@ -447,13 +449,15 @@ def convolutional_stem( kernel_size=3, stride=2, group_size=1, + act_layer=act_layer, inference_mode=inference_mode, ), MobileOneBlock( in_chs=out_chs, out_chs=out_chs, kernel_size=1, - stride=1, + stride=1, + act_layer=act_layer, inference_mode=inference_mode, ), ) @@ -1121,6 +1125,7 @@ class FastVit(nn.Module): in_chans, embed_dims[0], inference_mode, + act_layer ) # Build the main stages of the network architecture