Add global_pool to mambaout __init__ and pass to heads

Ross Wightman 2024-09-13 19:51:33 -07:00
parent 9d1dfe8dbe
commit 5dc5ee5b42


@@ -300,6 +300,7 @@ class MambaOut(nn.Module):
             self,
             in_chans=3,
             num_classes=1000,
+            global_pool='avg',
             depths=(3, 3, 9, 3),
             dims=(96, 192, 384, 576),
             norm_layer=LayerNorm,
@@ -369,7 +370,7 @@ class MambaOut(nn.Module):
             self.head = MlpHead(
                 prev_dim,
                 num_classes,
-                pool_type='avg',
+                pool_type=global_pool,
                 drop_rate=drop_rate,
                 norm_layer=norm_layer,
             )
@@ -379,7 +380,7 @@ class MambaOut(nn.Module):
                 prev_dim,
                 num_classes,
                 hidden_size=int(prev_dim * 4),
-                pool_type='avg',
+                pool_type=global_pool,
                 norm_layer=norm_layer,
                 drop_rate=drop_rate,
             )
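
For context, a minimal sketch of the new argument in use. It assumes a timm build that includes this commit and that mambaout_femto is one of the registered MambaOut variants; in timm, extra kwargs passed to create_model are forwarded to the model's __init__, which is how global_pool reaches the head constructors shown in the diff.

import timm
import torch

# 'mambaout_femto' is assumed to be a registered MambaOut variant.
# global_pool is forwarded to MambaOut.__init__ and on to the head
# (MlpHead or the alternative norm -> pool -> fc head) as pool_type.
model = timm.create_model('mambaout_femto', num_classes=10, global_pool='avg')

x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape)  # torch.Size([1, 10])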