LeViT safetensors load is broken by conversion code that wasn't deactivated

few_more_weights
Ross Wightman 2025-01-16 11:08:35 -08:00 committed by Ross Wightman
parent 21e75a9d25
commit 9265d54a3a
1 changed file with 11 additions and 10 deletions


@@ -763,17 +763,18 @@ def checkpoint_filter_fn(state_dict, model):
     # filter out attn biases, should not have been persistent
     state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}
-    D = model.state_dict()
-    out_dict = {}
-    for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
-        if va.ndim == 4 and vb.ndim == 2:
-            vb = vb[:, :, None, None]
-        if va.shape != vb.shape:
-            # head or first-conv shapes may change for fine-tune
-            assert 'head' in ka or 'stem.conv1.linear' in ka
-        out_dict[ka] = vb
+    # NOTE: old weight conversion code, disabled
+    # D = model.state_dict()
+    # out_dict = {}
+    # for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
+    #     if va.ndim == 4 and vb.ndim == 2:
+    #         vb = vb[:, :, None, None]
+    #     if va.shape != vb.shape:
+    #         # head or first-conv shapes may change for fine-tune
+    #         assert 'head' in ka or 'stem.conv1.linear' in ka
+    #     out_dict[ka] = vb
-    return out_dict
+    return state_dict
 model_cfgs = dict(
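
For context, after this change the checkpoint filter effectively reduces to dropping the non-persistent attention bias index buffers and returning the checkpoint unchanged. Below is a minimal sketch of the resulting behavior, assuming nothing beyond what the diff shows (simplified, not the exact file contents):

def checkpoint_filter_fn(state_dict, model):
    # Drop attention_bias_idxs buffers; as noted in the diff, they should not
    # have been persistent and are rebuilt by the model at construction time.
    state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}
    # The positional zip-based remapping against model.state_dict() is disabled,
    # so safetensors checkpoints that already match the model's key layout are
    # returned as-is instead of being remapped and shape-asserted.
    return state_dict

With the conversion loop commented out, the only transformation applied at load time is the key filter, which is what checkpoints exported from the current model layout expect.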