diff --git a/tests/test_models.py b/tests/test_models.py index bb2a92ed..aa866ccd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -56,7 +56,7 @@ FEAT_INTER_FILTERS = [ 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest', 'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext', - 'davit', + 'davit', 'rdnet', 'convnext', 'pit' ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 2f445118..e2eb48d3 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -486,8 +486,8 @@ class ConvNeXt(nn.Module): ): """ Prune layers not required for specified intermediates. """ - take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices) - self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0 + take_indices, max_index = feature_take_indices(len(self.stages), indices) + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 if prune_norm: self.norm_pre = nn.Identity() if prune_head: diff --git a/timm/models/pit.py b/timm/models/pit.py index 109cfaf8..1d5386a9 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -313,7 +313,7 @@ class PoolingVisionTransformer(nn.Module): """ Prune layers not required for specified intermediates. """ take_indices, max_index = feature_take_indices(len(self.transformers), indices) - self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 + self.transformers = self.transformers[:max_index + 1] # truncate blocks w/ stem as idx 0 if prune_norm: self.norm = nn.Identity() if prune_head: @@ -380,7 +380,7 @@ def _create_pit(variant, pretrained=False, **kwargs): variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, - feature_cfg=dict(feature_cls='hook', no_rewrite=True, out_indices=out_indices), + feature_cfg=dict(feature_cls='hook', out_indices=out_indices), **kwargs, ) return model