mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix loading pretrained model
This commit is contained in:
parent
bb50b69a57
commit
3718c5a5bd
@ -337,12 +337,24 @@ def _create_crossvit(variant, pretrained=False, **kwargs):
|
|||||||
if kwargs.get('features_only', None):
|
if kwargs.get('features_only', None):
|
||||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||||
|
|
||||||
|
def pretrained_filter_fn(state_dict):
|
||||||
|
new_state_dict = {}
|
||||||
|
for key in state_dict.keys():
|
||||||
|
if 'pos_embed' in key or 'cls_token' in key:
|
||||||
|
new_key = key.replace(".", "_")
|
||||||
|
else:
|
||||||
|
new_key = key
|
||||||
|
new_state_dict[new_key] = state_dict[key]
|
||||||
|
return new_state_dict
|
||||||
|
|
||||||
return build_model_with_cfg(
|
return build_model_with_cfg(
|
||||||
CrossViT, variant, pretrained,
|
CrossViT, variant, pretrained,
|
||||||
default_cfg=default_cfgs[variant],
|
default_cfg=default_cfgs[variant],
|
||||||
|
pretrained_filter_fn=pretrained_filter_fn,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def crossvit_tiny_224(pretrained=False, **kwargs):
|
def crossvit_tiny_224(pretrained=False, **kwargs):
|
||||||
model_args = dict(
|
model_args = dict(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user