Fix numel use in helpers for checkpoint remap

pull/1741/head
Ross Wightman 2023-03-20 09:36:48 -07:00
parent 2054f11c6f
commit 041de79f9e
1 changed files with 1 additions and 1 deletions

View File

@ -95,7 +95,7 @@ def remap_state_dict(
"""
out_dict = {}
for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
if va.shape != vb.shape:
if allow_reshape:
vb = vb.reshape(va.shape)