Mirror of https://github.com/huggingface/pytorch-image-models.git
Add global_pool to mambaout __init__ and pass to heads
Parent: 9d1dfe8dbe
Commit: 5dc5ee5b42
@@ -300,6 +300,7 @@ class MambaOut(nn.Module):
             self,
             in_chans=3,
             num_classes=1000,
+            global_pool='avg',
             depths=(3, 3, 9, 3),
             dims=(96, 192, 384, 576),
             norm_layer=LayerNorm,
@@ -369,7 +370,7 @@ class MambaOut(nn.Module):
             self.head = MlpHead(
                 prev_dim,
                 num_classes,
-                pool_type='avg',
+                pool_type=global_pool,
                 drop_rate=drop_rate,
                 norm_layer=norm_layer,
             )
@@ -379,7 +380,7 @@ class MambaOut(nn.Module):
                 prev_dim,
                 num_classes,
                 hidden_size=int(prev_dim * 4),
-                pool_type='avg',
+                pool_type=global_pool,
                 norm_layer=norm_layer,
                 drop_rate=drop_rate,
             )
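For context, a minimal usage sketch of what this change enables: the pooling type of the classifier head can now be chosen at construction time instead of always being the hard-coded 'avg'. This assumes timm's create_model forwards the global_pool keyword through to MambaOut.__init__ as added here; 'mambaout_femto' is used purely as an illustrative model name.

    # Minimal sketch, assuming create_model() passes global_pool through to
    # MambaOut.__init__; 'mambaout_femto' is an illustrative model name.
    import torch
    import timm

    # Before this commit the head's pool_type was fixed to 'avg'; now the
    # constructor argument is forwarded to the classifier head.
    model = timm.create_model('mambaout_femto', pretrained=False, global_pool='avg')

    x = torch.randn(1, 3, 224, 224)
    logits = model(x)
    print(logits.shape)  # e.g. torch.Size([1, 1000])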