mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
BEiT-V2 checkpoints didn't remove 'module' from weights, adapt checkpoint filter
This commit is contained in:
parent
73049dc2aa
commit
c8ab747bf4
@ -384,6 +384,13 @@ class Beit(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _beit_checkpoint_filter_fn(state_dict, model):
|
||||||
|
if 'module' in state_dict:
|
||||||
|
# beit v2 didn't strip module
|
||||||
|
state_dict = state_dict['module']
|
||||||
|
return checkpoint_filter_fn(state_dict, model)
|
||||||
|
|
||||||
|
|
||||||
def _create_beit(variant, pretrained=False, **kwargs):
|
def _create_beit(variant, pretrained=False, **kwargs):
|
||||||
if kwargs.get('features_only', None):
|
if kwargs.get('features_only', None):
|
||||||
raise RuntimeError('features_only not implemented for Beit models.')
|
raise RuntimeError('features_only not implemented for Beit models.')
|
||||||
@ -391,7 +398,7 @@ def _create_beit(variant, pretrained=False, **kwargs):
|
|||||||
model = build_model_with_cfg(
|
model = build_model_with_cfg(
|
||||||
Beit, variant, pretrained,
|
Beit, variant, pretrained,
|
||||||
# FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
|
# FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
|
||||||
pretrained_filter_fn=checkpoint_filter_fn,
|
pretrained_filter_fn=_beit_checkpoint_filter_fn,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user