mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix pit & add to test
This commit is contained in:
parent
2e9b2a76fb
commit
d1140c1a0f
@ -56,7 +56,7 @@ FEAT_INTER_FILTERS = [
|
||||
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*',
|
||||
'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
|
||||
'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
|
||||
'davit',
|
||||
'davit', 'rdnet', 'convnext', 'pit'
|
||||
]
|
||||
|
||||
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
|
||||
|
@ -486,8 +486,8 @@ class ConvNeXt(nn.Module):
|
||||
):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
|
||||
self.stages = self.stages[:max_index] # truncate blocks w/ stem as idx 0
|
||||
take_indices, max_index = feature_take_indices(len(self.stages), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_norm:
|
||||
self.norm_pre = nn.Identity()
|
||||
if prune_head:
|
||||
|
@ -313,7 +313,7 @@ class PoolingVisionTransformer(nn.Module):
|
||||
""" Prune layers not required for specified intermediates.
|
||||
"""
|
||||
take_indices, max_index = feature_take_indices(len(self.transformers), indices)
|
||||
self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
self.transformers = self.transformers[:max_index + 1] # truncate blocks w/ stem as idx 0
|
||||
if prune_norm:
|
||||
self.norm = nn.Identity()
|
||||
if prune_head:
|
||||
@ -380,7 +380,7 @@ def _create_pit(variant, pretrained=False, **kwargs):
|
||||
variant,
|
||||
pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(feature_cls='hook', no_rewrite=True, out_indices=out_indices),
|
||||
feature_cfg=dict(feature_cls='hook', out_indices=out_indices),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
Loading…
x
Reference in New Issue
Block a user