From ca20e102feea9d6147163243a654067e4d7e5031 Mon Sep 17 00:00:00 2001 From: Feraidoon Mehri <36224762+NightMachinery@users.noreply.github.com> Date: Thu, 17 Oct 2024 01:03:28 +0330 Subject: [PATCH] mambaout.py: fixed bug --- timm/models/mambaout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py index 3cc6e082..f53a9cdf 100644 --- a/timm/models/mambaout.py +++ b/timm/models/mambaout.py @@ -151,7 +151,7 @@ class MlpHead(nn.Module): self.num_features = in_features self.pre_logits = nn.Identity() - self.fc = nn.Linear(hidden_size, num_classes, bias=bias) if num_classes > 0 else nn.Identity() + self.fc = nn.Linear(self.num_features, num_classes, bias=bias) if num_classes > 0 else nn.Identity() self.head_dropout = nn.Dropout(drop_rate) def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):