diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py index c2f2f07b..bda69b11 100644 --- a/timm/models/mambaout.py +++ b/timm/models/mambaout.py @@ -300,6 +300,7 @@ class MambaOut(nn.Module): self, in_chans=3, num_classes=1000, + global_pool='avg', depths=(3, 3, 9, 3), dims=(96, 192, 384, 576), norm_layer=LayerNorm, @@ -369,7 +370,7 @@ class MambaOut(nn.Module): self.head = MlpHead( prev_dim, num_classes, - pool_type='avg', + pool_type=global_pool, drop_rate=drop_rate, norm_layer=norm_layer, ) @@ -379,7 +380,7 @@ class MambaOut(nn.Module): prev_dim, num_classes, hidden_size=int(prev_dim * 4), - pool_type='avg', + pool_type=global_pool, norm_layer=norm_layer, drop_rate=drop_rate, )