From fc0b6ad183de6ffbea8a03cc10afe3c986f3f8eb Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Sat, 3 May 2025 08:40:07 +0800 Subject: [PATCH] Fix default_cfgs --- tests/test_models.py | 2 +- timm/models/tnt.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 3ba3615d..35585a88 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -53,7 +53,7 @@ FEAT_INTER_FILTERS = [ 'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos', 'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2', 'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet', - 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*' + 'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt', ] # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output. diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 7decfa9a..8d83fbe0 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -20,7 +20,7 @@ from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resamp from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint -from ._registry import register_model +from ._registry import generate_default_cfgs, register_model __all__ = ['TNT'] # model_registry will add each entrypoint fn to this @@ -450,11 +450,14 @@ def _cfg(url='', **kwargs): 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'first_conv': 'pixel_embed.proj', 'classifier': 'head', + 'paper_ids': 'arXiv:2103.00112', + 'paper_name': 'Transformer in Transformer', + 'origin_url': 'https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch', **kwargs } -default_cfgs = { +default_cfgs = generate_default_cfgs({ 'tnt_s_patch16_224.in1k': _cfg( # hf_hub_id='timm/', # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar', @@ -464,7 +467,7 @@ default_cfgs = { # hf_hub_id='timm/', url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar', ), -} +}) def checkpoint_filter_fn(state_dict, model):