diff --git a/timm/models/davit.py b/timm/models/davit.py
index b58bbbbf..65009888 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -633,6 +633,7 @@ class DaVit(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py
index 5e3fb962..5608facb 100644
--- a/timm/models/focalnet.py
+++ b/timm/models/focalnet.py
@@ -455,6 +455,7 @@ class FocalNet(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py
index cba2d8b3..23ef3724 100644
--- a/timm/models/metaformer.py
+++ b/timm/models/metaformer.py
@@ -584,6 +584,7 @@ class MetaFormer(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         if global_pool is not None:
             self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
             self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py
index e05bb8b5..01e63fce 100644
--- a/timm/models/nextvit.py
+++ b/timm/models/nextvit.py
@@ -557,6 +557,7 @@ class NextViT(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py
index 2aa27564..68e92128 100644
--- a/timm/models/nfnet.py
+++ b/timm/models/nfnet.py
@@ -434,6 +434,7 @@ class NormFreeNet(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py
index dd7011b0..6977350a 100644
--- a/timm/models/pvt_v2.py
+++ b/timm/models/pvt_v2.py
@@ -384,7 +384,7 @@ class PyramidVisionTransformerV2(nn.Module):
         if global_pool is not None:
             assert global_pool in ('avg', '')
             self.global_pool = global_pool
-        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.patch_embed(x)
diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py
index af00548c..320b5c69 100644
--- a/timm/models/rdnet.py
+++ b/timm/models/rdnet.py
@@ -349,6 +349,7 @@ class RDNet(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
diff --git a/timm/models/regnet.py b/timm/models/regnet.py
index 240f8bcf..49a19aa1 100644
--- a/timm/models/regnet.py
+++ b/timm/models/regnet.py
@@ -515,6 +515,7 @@ class RegNet(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_intermediates(
diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py
index dec24c1f..37c37e2f 100644
--- a/timm/models/tresnet.py
+++ b/timm/models/tresnet.py
@@ -225,6 +225,7 @@ class TResNet(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py
index 3d0b37a4..75bb12e5 100644
--- a/timm/models/vision_transformer_sam.py
+++ b/timm/models/vision_transformer_sam.py
@@ -537,6 +537,7 @@ class VisionTransformerSAM(nn.Module):
         return self.head
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)
 
     def forward_intermediates(
diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py
index f9071ed3..3c67fa80 100644
--- a/timm/models/xception_aligned.py
+++ b/timm/models/xception_aligned.py
@@ -275,6 +275,7 @@ class XceptionAligned(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):