fix pit & add to test

This commit is contained in:
Ryan 2025-05-08 02:02:54 +08:00 committed by Ross Wightman
parent 2e9b2a76fb
commit d1140c1a0f
3 changed files with 5 additions and 5 deletions

View File

@ -56,7 +56,7 @@ FEAT_INTER_FILTERS = [
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*',
'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
'davit',
'davit', 'rdnet', 'convnext', 'pit'
]
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.

View File

@ -486,8 +486,8 @@ class ConvNeXt(nn.Module):
):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
take_indices, max_index = feature_take_indices(len(self.stages), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm_pre = nn.Identity()
if prune_head:

View File

@ -313,7 +313,7 @@ class PoolingVisionTransformer(nn.Module):
""" Prune layers not required for specified intermediates.
"""
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
self.transformers = self.transformers[:max_index + 1] # truncate blocks w/ stem as idx 0
if prune_norm:
self.norm = nn.Identity()
if prune_head:
@ -380,7 +380,7 @@ def _create_pit(variant, pretrained=False, **kwargs):
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(feature_cls='hook', no_rewrite=True, out_indices=out_indices),
feature_cfg=dict(feature_cls='hook', out_indices=out_indices),
**kwargs,
)
return model