Fix self.reset_classifier num_classes update

pull/2426/head
Ryan 2025-01-22 00:40:16 +08:00 committed by Ross Wightman
parent 84631cb5c6
commit 80a4877376
11 changed files with 11 additions and 1 deletion
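
A minimal sketch of the behaviour this commit addresses, assuming a recent timm install; 'davit_tiny' is used only as an example of one affected architecture. Before the fix, reset_classifier() on these models replaced the head but left the module's num_classes attribute at its old value.

    # Sketch only, not an exhaustive test. Assumes timm is installed.
    import timm

    model = timm.create_model('davit_tiny', pretrained=False, num_classes=1000)
    model.reset_classifier(num_classes=10)

    # Previously the head was swapped but the bookkeeping attribute was not,
    # so model.num_classes still reported 1000. With the added line it now
    # tracks the new head.
    assert model.num_classes == 10
    assert model.get_classifier().out_features == 10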

@@ -633,6 +633,7 @@ class DaVit(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):

@@ -455,6 +455,7 @@ class FocalNet(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):

@@ -584,6 +584,7 @@ class MetaFormer(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         if global_pool is not None:
             self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
             self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()

@@ -557,6 +557,7 @@ class NextViT(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):

@@ -434,6 +434,7 @@ class NormFreeNet(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):

@@ -384,7 +384,7 @@ class PyramidVisionTransformerV2(nn.Module):
         if global_pool is not None:
             assert global_pool in ('avg', '')
             self.global_pool = global_pool
-        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.patch_embed(x)
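
The PyramidVisionTransformerV2 hunk is the one replacement in this commit: the rebuilt head is now sized from num_features rather than embed_dim. A rough check of that change, again assuming a timm install and using 'pvt_v2_b0' only as an example variant:

    # Rough check of the PVT-V2 head sizing; sketch only.
    import timm

    m = timm.create_model('pvt_v2_b0', pretrained=False)
    m.reset_classifier(num_classes=5)

    # The replacement Linear head must accept the pooled feature width,
    # which is what self.num_features records.
    assert m.head.in_features == m.num_features
    assert m.head.out_features == 5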

@@ -349,6 +349,7 @@ class RDNet(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):

@@ -515,6 +515,7 @@ class RegNet(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_intermediates(

@@ -225,6 +225,7 @@ class TResNet(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):

@@ -537,6 +537,7 @@ class VisionTransformerSAM(nn.Module):
         return self.head
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
     def forward_intermediates(

@@ -275,6 +275,7 @@ class XceptionAligned(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):