LeViT safetensors load is broken by conversion code that wasn't deactivated
parent 21e75a9d25
commit 9265d54a3a
@@ -763,17 +763,18 @@ def checkpoint_filter_fn(state_dict, model):
     # filter out attn biases, should not have been persistent
     state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}
 
-    D = model.state_dict()
-    out_dict = {}
-    for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
-        if va.ndim == 4 and vb.ndim == 2:
-            vb = vb[:, :, None, None]
-        if va.shape != vb.shape:
-            # head or first-conv shapes may change for fine-tune
-            assert 'head' in ka or 'stem.conv1.linear' in ka
-        out_dict[ka] = vb
+    # NOTE: old weight conversion code, disabled
+    # D = model.state_dict()
+    # out_dict = {}
+    # for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
+    #     if va.ndim == 4 and vb.ndim == 2:
+    #         vb = vb[:, :, None, None]
+    #     if va.shape != vb.shape:
+    #         # head or first-conv shapes may change for fine-tune
+    #         assert 'head' in ka or 'stem.conv1.linear' in ka
+    #     out_dict[ka] = vb
 
-    return out_dict
+    return state_dict
 
 
 model_cfgs = dict(
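
For context, a minimal sketch of the load path this change affects, assuming a recent timm install with hub-hosted safetensors weights (the model name below is just an illustrative LeViT variant):

import timm

# Creating any LeViT variant with pretrained=True routes the downloaded
# checkpoint through checkpoint_filter_fn before load_state_dict. The
# safetensors hub weights are already in timm's layout, so the old
# conversion loop shown above broke loading; with it disabled, the state
# dict passes through unchanged apart from dropping 'attention_bias_idxs'.
model = timm.create_model('levit_128s', pretrained=True)
model.eval()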