Fix default_cfgs

Ryan 2025-05-03 08:40:07 +08:00
parent 848b8c3e57
commit fc0b6ad183
2 changed files with 7 additions and 4 deletions

tests/test_models.py

@@ -53,7 +53,7 @@ FEAT_INTER_FILTERS = [
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
     'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*'
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
 ]
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
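
Adding 'tnt' here opts the TNT models into the test suite's feature-intermediates checks; the matching change below wires tnt.py up to that API via feature_take_indices. A minimal smoke-test sketch of what the gate now exercises, assuming timm's forward_intermediates() API (model name taken from the cfgs in this commit, pretrained=False to avoid a weight download):

import torch
import timm

# Sketch only: mirrors the shape checks gated by FEAT_INTER_FILTERS,
# not the exact test in tests/test_models.py.
model = timm.create_model('tnt_s_patch16_224', pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    # Returns final features plus a list of intermediate feature tensors.
    final, intermediates = model.forward_intermediates(x)
print(final.shape, [t.shape for t in intermediates])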

timm/models/tnt.py

@@ -20,7 +20,7 @@ from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resamp
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import checkpoint
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['TNT']  # model_registry will add each entrypoint fn to this
@@ -450,11 +450,14 @@ def _cfg(url='', **kwargs):
         'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
         'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
         'first_conv': 'pixel_embed.proj', 'classifier': 'head',
+        'paper_ids': 'arXiv:2103.00112',
+        'paper_name': 'Transformer in Transformer',
+        'origin_url': 'https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch',
         **kwargs
     }
 
 
-default_cfgs = {
+default_cfgs = generate_default_cfgs({
     'tnt_s_patch16_224.in1k': _cfg(
         # hf_hub_id='timm/',
         # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
@@ -464,7 +467,7 @@ default_cfgs = {
         # hf_hub_id='timm/',
         url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
     ),
-}
+})
 
 
 def checkpoint_filter_fn(state_dict, model):
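
The substantive fix is the generate_default_cfgs() wrapper: a bare dict keyed by 'model.tag' names is not what the registry consumes, so the '.in1k' pretrained tags on the TNT cfgs were not being resolved. generate_default_cfgs() splits each 'model.tag' key into a per-architecture default cfg with its tags, which register_model then attaches to the entrypoints. A toy sketch of the behavior (abbreviated cfg dicts, not the full tnt.py set):

from timm.models import generate_default_cfgs

# Toy illustration: 'arch.tag' keys collapse into one default cfg per arch,
# with the pretrained tag ('in1k') tracked on each.
cfgs = generate_default_cfgs({
    'tnt_s_patch16_224.in1k': {'num_classes': 1000, 'input_size': (3, 224, 224)},
    'tnt_b_patch16_224.in1k': {'num_classes': 1000, 'input_size': (3, 224, 224)},
})
print(list(cfgs.keys()))  # ['tnt_s_patch16_224', 'tnt_b_patch16_224']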