LeViT safetensors load is broken by conversion code that wasn't deactivated

few_more_weights
Ross Wightman 2025-01-16 11:08:35 -08:00 committed by Ross Wightman
parent 21e75a9d25
commit 9265d54a3a
1 changed file with 11 additions and 10 deletions

@@ -763,17 +763,18 @@ def checkpoint_filter_fn(state_dict, model):
     # filter out attn biases, should not have been persistent
     state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}
-    D = model.state_dict()
-    out_dict = {}
-    for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
-        if va.ndim == 4 and vb.ndim == 2:
-            vb = vb[:, :, None, None]
-        if va.shape != vb.shape:
-            # head or first-conv shapes may change for fine-tune
-            assert 'head' in ka or 'stem.conv1.linear' in ka
-        out_dict[ka] = vb
-    return out_dict
+    # NOTE: old weight conversion code, disabled
+    # D = model.state_dict()
+    # out_dict = {}
+    # for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
+    #     if va.ndim == 4 and vb.ndim == 2:
+    #         vb = vb[:, :, None, None]
+    #     if va.shape != vb.shape:
+    #         # head or first-conv shapes may change for fine-tune
+    #         assert 'head' in ka or 'stem.conv1.linear' in ka
+    #     out_dict[ka] = vb
+    return state_dict


 model_cfgs = dict(
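
Why the removed loop is fragile: it paired checkpoint tensors with model parameters by position, zipping the keys of model.state_dict() against the keys of the incoming state_dict. That only works when both dicts enumerate the same keys in the same order, which held for the original upstream LeViT checkpoints the conversion was written for but is not guaranteed once a checkpoint has been re-serialized (e.g. to safetensors) or had keys filtered out. A minimal sketch with hypothetical keys and shapes:

import torch

# Order as the model registers its parameters.
model_sd = {
    'stem.conv1.linear.weight': torch.zeros(16, 3, 3, 3),
    'head.weight': torch.zeros(1000, 384),
}
# Same tensors, different serialization order.
ckpt_sd = {
    'head.weight': torch.zeros(1000, 384),
    'stem.conv1.linear.weight': torch.zeros(16, 3, 3, 3),
}

for ka, kb in zip(model_sd, ckpt_sd):
    # Positional pairing assigns the head weight to the stem conv and vice
    # versa; the old loop's shape assert is what surfaced such mismatches.
    print(f'{ka} <- {kb}')

Returning state_dict unchanged instead defers matching to load_state_dict(), which pairs tensors by name and is insensitive to key order.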