mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix default_cfgs
This commit is contained in:
parent
848b8c3e57
commit
fc0b6ad183
@ -53,7 +53,7 @@ FEAT_INTER_FILTERS = [
|
||||
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
|
||||
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
|
||||
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
|
||||
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*'
|
||||
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
|
||||
]
|
||||
|
||||
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
|
||||
|
@ -20,7 +20,7 @@ from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resamp
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
from ._manipulate import checkpoint
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
|
||||
__all__ = ['TNT'] # model_registry will add each entrypoint fn to this
|
||||
@ -450,11 +450,14 @@ def _cfg(url='', **kwargs):
|
||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': 'pixel_embed.proj', 'classifier': 'head',
|
||||
'paper_ids': 'arXiv:2103.00112',
|
||||
'paper_name': 'Transformer in Transformer',
|
||||
'origin_url': 'https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'tnt_s_patch16_224.in1k': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
# url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
|
||||
@ -464,7 +467,7 @@ default_cfgs = {
|
||||
# hf_hub_id='timm/',
|
||||
url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
|
||||
),
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
|
Loading…
x
Reference in New Issue
Block a user