Default to old checkpoint format for now, still want compatibility with older torch ver for released models

This commit is contained in:
Ross Wightman 2020-10-13 12:58:04 -07:00
parent a4d8fea61e
commit 9305313291
2 changed files with 10 additions and 2 deletions

View File

@ -103,7 +103,11 @@ def main():
v = v.clamp(float32_info.min, float32_info.max)
final_state_dict[k] = v.to(dtype=torch.float32)
torch.save(final_state_dict, args.output)
try:
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
except:
torch.save(final_state_dict, args.output)
with open(args.output, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))

View File

@ -57,7 +57,11 @@ def main():
new_state_dict[name] = v
print("=> Loaded state_dict from '{}'".format(args.checkpoint))
torch.save(new_state_dict, _TEMP_NAME)
try:
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
except:
torch.save(new_state_dict, _TEMP_NAME)
with open(_TEMP_NAME, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()