From fe25c8d3fbf913b1b6a7af4b5b618307db68cf6c Mon Sep 17 00:00:00 2001 From: Ezra-Yu <1105212286@qq.com> Date: Fri, 10 Sep 2021 11:42:38 +0800 Subject: [PATCH] [Enchence] Add datetime info and saving model using torch<1.6 format (#439) * Add date and save ckpt usingg torch<1.6 format * fix lint * use digit_version and rasie error when there is no target out_dir * add '.' --- .../publish_model.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) rename tools/{model_converters => convert_models}/publish_model.py (62%) diff --git a/tools/model_converters/publish_model.py b/tools/convert_models/publish_model.py similarity index 62% rename from tools/model_converters/publish_model.py rename to tools/convert_models/publish_model.py index 5d3912e45..b3350c9cd 100644 --- a/tools/model_converters/publish_model.py +++ b/tools/convert_models/publish_model.py @@ -1,8 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import datetime +import os import subprocess import torch +from mmcv import digit_version def parse_args(): @@ -21,18 +24,30 @@ def process_checkpoint(in_file, out_file): del checkpoint['optimizer'] # if it is necessary to remove some sensitive data in checkpoint['meta'], # add the code here. - torch.save(checkpoint, out_file) + if digit_version(torch.__version__) >= digit_version('1.6'): + torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False) + else: + torch.save(checkpoint, out_file) + sha = subprocess.check_output(['sha256sum', out_file]).decode() if out_file.endswith('.pth'): out_file_name = out_file[:-4] else: out_file_name = out_file - final_file = out_file_name + f'-{sha[:8]}.pth' + + current_date = datetime.datetime.now().strftime('%Y%m%d') + final_file = out_file_name + f'_{current_date}-{sha[:8]}.pth' subprocess.Popen(['mv', out_file, final_file]) + print(f'Successfully generated the publish-ckpt as {final_file}.') + def main(): args = parse_args() + out_dir = os.path.dirname(args.out_file) + if not os.path.exists(out_dir): + raise ValueError(f'Directory {out_dir} does not exist, ' + 'please generate it manually.') process_checkpoint(args.in_file, args.out_file)