Fix numel use in helpers for checkpoint remap
parent
2054f11c6f
commit
041de79f9e
|
@ -95,7 +95,7 @@ def remap_state_dict(
|
|||
"""
|
||||
out_dict = {}
|
||||
for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
|
||||
assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
|
||||
assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
|
||||
if va.shape != vb.shape:
|
||||
if allow_reshape:
|
||||
vb = vb.reshape(va.shape)
|
||||
|
|
Loading…
Reference in New Issue