Another small fix for original mambaout models, no classifier nn.Linear when num_classes=0 on init

intern300m
Ross Wightman 2024-10-16 12:36:36 -07:00
parent fad4538801
commit 89dffc5ff0
1 changed file with 1 addition and 1 deletion

View File

@ -151,7 +151,7 @@ class MlpHead(nn.Module):
         self.num_features = in_features
         self.pre_logits = nn.Identity()
-        self.fc = nn.Linear(hidden_size, num_classes, bias=bias)
+        self.fc = nn.Linear(hidden_size, num_classes, bias=bias) if num_classes > 0 else nn.Identity()
         self.head_dropout = nn.Dropout(drop_rate)

     def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):