Mirror of https://github.com/huggingface/pytorch-image-models.git
BEiT-V2 checkpoints didn't remove 'module' from weights, adapt checkpoint filter
commit c8ab747bf4
parent 73049dc2aa
@@ -384,6 +384,13 @@ class Beit(nn.Module):
         return x
 
 
+def _beit_checkpoint_filter_fn(state_dict, model):
+    if 'module' in state_dict:
+        # beit v2 didn't strip module
+        state_dict = state_dict['module']
+    return checkpoint_filter_fn(state_dict, model)
+
+
 def _create_beit(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Beit models.')
@@ -391,7 +398,7 @@ def _create_beit(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         Beit, variant, pretrained,
         # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
-        pretrained_filter_fn=checkpoint_filter_fn,
+        pretrained_filter_fn=_beit_checkpoint_filter_fn,
         **kwargs)
     return model
 