LeViT safetensors load is broken by conversion code that wasn't deactivated
parent 21e75a9d25
commit 9265d54a3a
@@ -763,17 +763,18 @@ def checkpoint_filter_fn(state_dict, model):
     # filter out attn biases, should not have been persistent
     state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}
 
-    D = model.state_dict()
-    out_dict = {}
-    for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
-        if va.ndim == 4 and vb.ndim == 2:
-            vb = vb[:, :, None, None]
-        if va.shape != vb.shape:
-            # head or first-conv shapes may change for fine-tune
-            assert 'head' in ka or 'stem.conv1.linear' in ka
-        out_dict[ka] = vb
+    # NOTE: old weight conversion code, disabled
+    # D = model.state_dict()
+    # out_dict = {}
+    # for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
+    #     if va.ndim == 4 and vb.ndim == 2:
+    #         vb = vb[:, :, None, None]
+    #     if va.shape != vb.shape:
+    #         # head or first-conv shapes may change for fine-tune
+    #         assert 'head' in ka or 'stem.conv1.linear' in ka
+    #     out_dict[ka] = vb
 
-    return out_dict
+    return state_dict
 
 
 model_cfgs = dict(
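Below is a hypothetical sketch of why zip()-based remapping like the disabled loop is fragile. The failure mode shown is illustrative, not stated in the commit, and the key names are invented: pairing model parameters with checkpoint tensors purely by position only works when both dicts iterate their keys in the same order, whereas matching by key, which the plain `return state_dict` path relies on, is order-independent.

# Hypothetical illustration only; keys and values are invented, not taken from LeViT.
model_sd = {'stem.conv1.linear.weight': 'conv_w', 'head.weight': 'head_w'}
ckpt_sd = {'head.weight': 'head_w_ckpt', 'stem.conv1.linear.weight': 'conv_w_ckpt'}  # same keys, different order

# Positional pairing, as in the disabled loop: values land on the wrong keys.
by_position = {ka: vb for ka, vb in zip(model_sd.keys(), ckpt_sd.values())}
print(by_position)  # {'stem.conv1.linear.weight': 'head_w_ckpt', 'head.weight': 'conv_w_ckpt'}

# Matching by key, as the normal load path does, does not depend on ordering.
by_key = {k: ckpt_sd[k] for k in model_sd}
print(by_key)       # {'stem.conv1.linear.weight': 'conv_w_ckpt', 'head.weight': 'head_w_ckpt'}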