mmrazor/tools/model_converters/publish_model.py

91 lines
2.8 KiB
Python
Raw Normal View History

2021-12-23 03:09:46 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import datetime
from pathlib import Path
2022-07-15 19:03:20 +08:00
from typing import Optional
2021-12-23 03:09:46 +08:00
import mmcv
import torch
from mmcv import digit_version
def parse_args(argv=None):
    """Parse the command-line arguments of the publish script.

    Args:
        argv (list[str], optional): Arguments to parse. Defaults to None,
            in which case argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Has ``in_file``, ``out_file`` and
        ``subnet_cfg_file`` (``None`` when the subnet config is omitted).
    """
    parser = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    parser.add_argument('in_file', help='input checkpoint filename', type=str)
    # A plain positional argument is always required, so a ``default`` on it
    # has no effect. ``out_file`` stays required: main() derives the output
    # directory from it.
    parser.add_argument(
        'out_file', help='output checkpoint filename', type=str)
    # ``nargs='?'`` is what actually makes this positional optional; without
    # it ``default=None`` was never used and the script refused to run
    # unless a subnet config was passed.
    parser.add_argument(
        'subnet_cfg_file',
        nargs='?',
        help='input subnet config filename',
        default=None,
        type=str)
    args = parser.parse_args(argv)
    return args
def cal_file_sha256(file_path: str) -> str:
    """Return the hex SHA-256 digest of the file at ``file_path``.

    The file is read in fixed-size binary chunks, so large checkpoints are
    hashed without loading them fully into memory.
    """
    import hashlib
    chunk_size = 65536
    digest = hashlib.sha256()
    with open(file_path, 'rb') as handle:
        # iter() with a b'' sentinel stops at end of file.
        for chunk in iter(lambda: handle.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()
2022-07-15 19:03:20 +08:00
def process_checkpoint(in_file: str,
                       out_file: Optional[str] = None,
                       subnet_cfg_file: Optional[str] = None) -> None:
    """Strip a checkpoint for release and rename it with date and hash.

    The ``'optimizer'`` entry (if present) is removed. The checkpoint is
    saved to ``out_file``, and the saved file is then renamed to
    ``<out_file without .pth>_<YYYYMMDD>-<first 8 hex chars of sha256>.pth``.
    If ``subnet_cfg_file`` is given, it is loaded and dumped again as
    ``<same prefix>_subnet_cfg.yaml``.

    Args:
        in_file: Path of the checkpoint to publish (str or os.PathLike).
        out_file: Path the stripped checkpoint is first written to (str or
            os.PathLike). Defaults to ``in_file``, in which case the input
            file is overwritten and then renamed.
        subnet_cfg_file: Optional path of a subnet config to publish next to
            the checkpoint.
    """
    # NOTE(review): torch.load unpickles the file; only run this on
    # checkpoints from a trusted source.
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']

    # Accept str or os.PathLike; the '.pth' suffix handling below needs str.
    out_file = str(in_file if out_file is None else out_file)

    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    if digit_version(torch.__version__) >= digit_version('1.6'):
        # Keep the legacy (non-zipfile) serialization; presumably so that
        # torch < 1.6 can still load the published file.
        torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
    else:
        torch.save(checkpoint, out_file)

    sha = cal_file_sha256(out_file)
    out_file_name = out_file[:-4] if out_file.endswith('.pth') else out_file
    current_date = datetime.datetime.now().strftime('%Y%m%d')
    final_file_prefix = f'{out_file_name}_{current_date}-{sha[:8]}'
    final_ckpt_file = f'{final_file_prefix}.pth'
    Path(out_file).rename(final_ckpt_file)
    print(f'Successfully generated the publish-ckpt as {final_ckpt_file}.')

    if subnet_cfg_file is not None:
        subnet_cfg = mmcv.fileio.load(subnet_cfg_file)
        final_subnet_cfg_file = f'{final_file_prefix}_subnet_cfg.yaml'
        mmcv.fileio.dump(subnet_cfg, final_subnet_cfg_file)
        # Implicit concatenation instead of a backslash continuation, which
        # leaked the next line's indentation into the printed message.
        print('Successfully generated the publish-subnet-cfg as '
              f'{final_subnet_cfg_file}.')
2021-12-23 03:09:46 +08:00
def main():
    """Command-line entry point: check the output directory, then publish.

    Raises:
        ValueError: If the parent directory of ``out_file`` does not exist.
    """
    args = parse_args()
    target_dir = Path(args.out_file).parent
    if target_dir.exists():
        process_checkpoint(args.in_file, args.out_file, args.subnet_cfg_file)
    else:
        raise ValueError(f'Directory {target_dir} does not exist, '
                         'please generate it manually.')
2021-12-23 03:09:46 +08:00
if __name__ == '__main__':
    # Run the command-line tool only when this file is executed as a script.
    main()