BEiT-V2 checkpoints didn't remove 'module' from weights, adapt checkpoint filter

Ross Wightman 2022-09-13 17:56:49 -07:00
parent 73049dc2aa
commit c8ab747bf4


@@ -384,6 +384,13 @@ class Beit(nn.Module):
         return x
 
 
+def _beit_checkpoint_filter_fn(state_dict, model):
+    if 'module' in state_dict:
+        # beit v2 didn't strip module
+        state_dict = state_dict['module']
+    return checkpoint_filter_fn(state_dict, model)
+
+
 def _create_beit(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Beit models.')
@@ -391,7 +398,7 @@ def _create_beit(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         Beit, variant, pretrained,
         # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
-        pretrained_filter_fn=checkpoint_filter_fn,
+        pretrained_filter_fn=_beit_checkpoint_filter_fn,
         **kwargs)
     return model
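
Context for the fix: the released BEiT-V2 checkpoints keep the entire state dict nested under a single 'module' key (as DDP/DeepSpeed-style training wrappers commonly save it), so without unwrapping, none of the weights would match the model's keys. Below is a minimal standalone sketch of the unwrap behavior; the checkpoint contents are made up for illustration, and checkpoint_filter_fn is stubbed in place of timm's real remap function (which also handles key renaming and relative position embedding resizing).

import torch

def checkpoint_filter_fn(state_dict, model):
    # Stand-in stub for timm's existing filter; the real one remaps keys,
    # resizes pos embeds, etc. Here it just passes the dict through.
    return state_dict

def _beit_checkpoint_filter_fn(state_dict, model):
    if 'module' in state_dict:
        # beit v2 didn't strip module: descend one level before filtering
        state_dict = state_dict['module']
    return checkpoint_filter_fn(state_dict, model)

# Hypothetical checkpoints: v1-style (flat) vs. v2-style (nested under 'module').
v1_ckpt = {'cls_token': torch.zeros(1, 1, 768)}
v2_ckpt = {'module': {'cls_token': torch.zeros(1, 1, 768)}}

# Both shapes now yield the same flat key set for loading.
assert _beit_checkpoint_filter_fn(v1_ckpt, None).keys() == \
       _beit_checkpoint_filter_fn(v2_ckpt, None).keys()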