diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py index 8eac6e7b..71d12fe6 100644 --- a/timm/models/mambaout.py +++ b/timm/models/mambaout.py @@ -479,7 +479,6 @@ class MambaOut(nn.Module): self.reset_classifier(0, '') return take_indices - def forward_features(self, x): x = self.stem(x) x = self.stages(x) diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index e4375b34..b7d4e7e4 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -1302,7 +1302,8 @@ class MaxxVit(nn.Module): if intermediates_only: return intermediates - x = self.norm(x) + if feat_idx == last_idx: + x = self.norm(x) return x, intermediates diff --git a/timm/models/nest.py b/timm/models/nest.py index 9ee50463..8b4ce5ed 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -449,6 +449,7 @@ class Nest(nn.Module): # forward pass x = self.patch_embed(x) + last_idx = self.num_blocks - 1 if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript stages = self.levels else: @@ -457,13 +458,18 @@ class Nest(nn.Module): for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + if norm and feat_idx == last_idx: + x_inter = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + intermediates.append(x_inter) + else: + intermediates.append(x) if intermediates_only: return intermediates - # Layer norm done over channel dim only (to NHWC and back) - x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + if feat_idx == last_idx: + # Layer norm done over channel dim only (to NHWC and back) + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return x, intermediates diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py index 5d6ec972..2f232e29 100644 --- a/timm/models/nextvit.py +++ b/timm/models/nextvit.py @@ -588,6 +588,7 @@ class NextViT(nn.Module): # forward pass x = self.stem(x) + last_idx = len(self.stages) - 1 if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript stages = self.stages else: @@ -596,12 +597,17 @@ class NextViT(nn.Module): for feat_idx, stage in enumerate(stages): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + if feat_idx == last_idx: + x_inter = self.norm(x) if norm else x + intermediates.append(x_inter) + else: + intermediates.append(x) if intermediates_only: return intermediates - x = self.norm(x) + if feat_idx == last_idx: + x = self.norm(x) return x, intermediates diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py index b55cc33c..3c556e37 100644 --- a/timm/models/rdnet.py +++ b/timm/models/rdnet.py @@ -309,7 +309,7 @@ class RDNet(nn.Module): x = self.stem(x) if feat_idx in take_indices: intermediates.append(x) - + last_idx = len(self.dense_stages) if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript dense_stages = self.dense_stages else: @@ -324,7 +324,8 @@ class RDNet(nn.Module): if intermediates_only: return intermediates - x = self.norm_pre(x) + if feat_idx == last_idx: + x = self.norm_pre(x) return x, intermediates diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 1bac794c..1cc3b864 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -574,7 +574,7 @@ class ResNetV2(nn.Module): x = self.stem(x) if feat_idx in take_indices: intermediates.append(x) - + last_idx = len(self.stages) if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript stages = self.stages else: @@ -583,12 +583,17 @@ class ResNetV2(nn.Module): for feat_idx, stage in enumerate(stages, start=1): x = stage(x) if feat_idx in take_indices: - intermediates.append(x) + if feat_idx == last_idx: + x_inter = self.norm(x) if norm else x + intermediates.append(x_inter) + else: + intermediates.append(x) if intermediates_only: return intermediates - x = self.norm(x) + if feat_idx == last_idx: + x = self.norm(x) return x, intermediates