mambaout.py: fixed bug

pull/2305/head
Feraidoon Mehri 2024-10-17 01:03:28 +03:30 committed by GitHub
parent 8cb2548962
commit ca20e102fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -151,7 +151,7 @@ class MlpHead(nn.Module):
self.num_features = in_features
self.pre_logits = nn.Identity()
self.fc = nn.Linear(hidden_size, num_classes, bias=bias) if num_classes > 0 else nn.Identity()
self.fc = nn.Linear(self.num_features, num_classes, bias=bias) if num_classes > 0 else nn.Identity()
self.head_dropout = nn.Dropout(drop_rate)
def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):