Fix self.reset_classifier num_classes update
parent
84631cb5c6
commit
80a4877376
|
@ -633,6 +633,7 @@ class DaVit(nn.Module):
|
|||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
|
|
|
@ -455,6 +455,7 @@ class FocalNet(nn.Module):
|
|||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
|
|
|
@ -584,6 +584,7 @@ class MetaFormer(nn.Module):
|
|||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
if global_pool is not None:
|
||||
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
||||
|
|
|
@ -557,6 +557,7 @@ class NextViT(nn.Module):
|
|||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
|
|
|
@ -434,6 +434,7 @@ class NormFreeNet(nn.Module):
|
|||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
|
|
|
@ -384,7 +384,7 @@ class PyramidVisionTransformerV2(nn.Module):
|
|||
if global_pool is not None:
|
||||
assert global_pool in ('avg', '')
|
||||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
|
|
|
@ -349,6 +349,7 @@ class RDNet(nn.Module):
|
|||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
|
|
|
@ -515,6 +515,7 @@ class RegNet(nn.Module):
|
|||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
|
|
|
@ -225,6 +225,7 @@ class TResNet(nn.Module):
|
|||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
|
|
|
@ -537,6 +537,7 @@ class VisionTransformerSAM(nn.Module):
|
|||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_intermediates(
|
||||
|
|
|
@ -275,6 +275,7 @@ class XceptionAligned(nn.Module):
|
|||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
|
|
Loading…
Reference in New Issue