Merge pull request #2305 from NightMachinery/patch-2

mambaout.py: fixed bug
This commit is contained in:
Ross Wightman 2024-10-16 14:39:43 -07:00 committed by GitHub
commit a852318b63
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -151,7 +151,7 @@ class MlpHead(nn.Module):
self.num_features = in_features
self.pre_logits = nn.Identity()
self.fc = nn.Linear(hidden_size, num_classes, bias=bias) if num_classes > 0 else nn.Identity()
self.fc = nn.Linear(self.num_features, num_classes, bias=bias) if num_classes > 0 else nn.Identity()
self.head_dropout = nn.Dropout(drop_rate)
def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):