Another small fix for original mambaout models: no classifier nn.Linear when num_classes=0 on init
parent fad4538801
commit 89dffc5ff0
@@ -151,7 +151,7 @@ class MlpHead(nn.Module):
         self.num_features = in_features
         self.pre_logits = nn.Identity()
 
-        self.fc = nn.Linear(hidden_size, num_classes, bias=bias)
+        self.fc = nn.Linear(hidden_size, num_classes, bias=bias) if num_classes > 0 else nn.Identity()
         self.head_dropout = nn.Dropout(drop_rate)
 
     def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
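For context, a minimal sketch of the practical effect of this change. The model name 'mambaout_base' and the model.head.fc attribute path are assumptions based on timm's usual conventions, not part of this commit:

import torch
import timm

# Request a headless mambaout model. Before this fix, MlpHead still
# constructed an nn.Linear with zero output features when num_classes=0;
# after it, the classifier slot becomes nn.Identity as in other timm heads.
model = timm.create_model('mambaout_base', num_classes=0)
assert isinstance(model.head.fc, torch.nn.Identity)

# With the classifier gone, the model acts as a feature extractor:
# forward() returns the pre-classifier head features rather than logits.
x = torch.randn(1, 3, 224, 224)
features = model(x)
print(features.shape)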