mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add global_pool to mambaout __init__ and pass to heads
This commit is contained in:
parent
9d1dfe8dbe
commit
5dc5ee5b42
@ -300,6 +300,7 @@ class MambaOut(nn.Module):
|
|||||||
self,
|
self,
|
||||||
in_chans=3,
|
in_chans=3,
|
||||||
num_classes=1000,
|
num_classes=1000,
|
||||||
|
global_pool='avg',
|
||||||
depths=(3, 3, 9, 3),
|
depths=(3, 3, 9, 3),
|
||||||
dims=(96, 192, 384, 576),
|
dims=(96, 192, 384, 576),
|
||||||
norm_layer=LayerNorm,
|
norm_layer=LayerNorm,
|
||||||
@ -369,7 +370,7 @@ class MambaOut(nn.Module):
|
|||||||
self.head = MlpHead(
|
self.head = MlpHead(
|
||||||
prev_dim,
|
prev_dim,
|
||||||
num_classes,
|
num_classes,
|
||||||
pool_type='avg',
|
pool_type=global_pool,
|
||||||
drop_rate=drop_rate,
|
drop_rate=drop_rate,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
)
|
)
|
||||||
@ -379,7 +380,7 @@ class MambaOut(nn.Module):
|
|||||||
prev_dim,
|
prev_dim,
|
||||||
num_classes,
|
num_classes,
|
||||||
hidden_size=int(prev_dim * 4),
|
hidden_size=int(prev_dim * 4),
|
||||||
pool_type='avg',
|
pool_type=global_pool,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
drop_rate=drop_rate,
|
drop_rate=drop_rate,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user