From d6c2cc91af13e89990c05979638e8ffaa0335b5e Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 10 Feb 2024 16:25:33 -0800
Subject: [PATCH] Make NormMlpClassifier head reset args consistent with ClassifierHead

---
 timm/layers/classifier.py | 8 ++++----
 timm/models/davit.py      | 2 +-
 timm/models/tiny_vit.py   | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py
index 2eb4ec2e..71e45c87 100644
--- a/timm/layers/classifier.py
+++ b/timm/layers/classifier.py
@@ -180,10 +180,10 @@ class NormMlpClassifierHead(nn.Module):
         self.drop = nn.Dropout(drop_rate)
         self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
-    def reset(self, num_classes, global_pool=None):
-        if global_pool is not None:
-            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+    def reset(self, num_classes, pool_type=None):
+        if pool_type is not None:
+            self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
+            self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
         self.use_conv = self.global_pool.is_identity()
         linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
         if self.hidden_size:
diff --git a/timm/models/davit.py b/timm/models/davit.py
index e7f2ed0e..d4d6ad69 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -569,7 +569,7 @@ class DaVit(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes, global_pool=None):
-        self.head.reset(num_classes, global_pool=global_pool)
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py
index 96a88db7..b4b29648 100644
--- a/timm/models/tiny_vit.py
+++ b/timm/models/tiny_vit.py
@@ -535,7 +535,7 @@ class TinyVit(nn.Module):
 
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
-        self.head.reset(num_classes, global_pool=global_pool)
+        self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
         x = self.patch_embed(x)
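
Note (not part of the patch): after this change, NormMlpClassifierHead.reset takes the same (num_classes, pool_type) arguments as ClassifierHead.reset, so model-level reset_classifier wrappers can forward the pool argument either positionally or as pool_type. A minimal usage sketch of the resulting behavior, assuming the timm package is installed; the model name 'davit_tiny' and the class count are illustrative choices, not taken from the patch:

import timm

# Create a model whose head is a NormMlpClassifierHead (DaVit uses one).
# 'davit_tiny' is an assumed/illustrative model name here.
model = timm.create_model('davit_tiny', pretrained=False)

# Replace the classifier for a 10-class task; global_pool is forwarded to
# the head's reset() as its pool argument, matching ClassifierHead.
model.reset_classifier(10, global_pool='avg')

# Passing num_classes=0 removes the classifier (fc becomes nn.Identity).
model.reset_classifier(0)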