48 lines
1.4 KiB
Python
48 lines
1.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
|
|
|
|
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: ``checkpoint`` holds the input checkpoint path
        and ``inplace`` is ``True`` when the input file should be
        overwritten instead of writing a new file next to it.
    """
    cli = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    cli.add_argument('checkpoint', help='input checkpoint filename')
    cli.add_argument(
        '--inplace',
        action='store_true',
        help='replace origin ckpt',
    )
    return cli.parse_args()
|
|
|
|
|
|
def _rename_key(key):
    """Map one ``state_dict`` key from the old naming layout to the new one.

    ``architecture.model.distiller.teacher...`` becomes
    ``architecture.teacher...``, other ``architecture.model...`` keys become
    ``architecture...``, and every other key is returned unchanged. Only the
    leading prefix is rewritten; a matching substring later in the key is
    left alone (``str.replace`` would have rewritten every occurrence).
    """
    teacher_prefix = 'architecture.model.distiller.teacher'
    model_prefix = 'architecture.model'
    # The teacher prefix must be tested first: it also starts with
    # ``model_prefix`` and would otherwise be caught by the broader rule.
    if key.startswith(teacher_prefix):
        return 'architecture.teacher' + key[len(teacher_prefix):]
    if key.startswith(model_prefix):
        return 'architecture' + key[len(model_prefix):]
    return key


def main():
    """Rename the keys of a checkpoint's ``state_dict`` and save the result.

    The checkpoint named on the command line is loaded onto the CPU, each
    key of ``checkpoint['state_dict']`` is rewritten with ``_rename_key``
    (values are kept as-is), and the checkpoint is saved either over the
    input file (``--inplace``) or as ``<stem>_latest.pth`` in the same
    directory as the input.
    """
    args = parse_args()
    # NOTE(review): torch.load unpickles the file, so only run this script
    # on checkpoints from a trusted source.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')

    # If two old keys map to the same new key, the later one wins, matching
    # plain dict-assignment order.
    checkpoint['state_dict'] = {
        _rename_key(key): value
        for key, value in checkpoint['state_dict'].items()
    }

    if args.inplace:
        torch.save(checkpoint, args.checkpoint)
    else:
        ckpt_path = Path(args.checkpoint)
        new_ckpt_path = ckpt_path.parent / f'{ckpt_path.stem}_latest.pth'
        torch.save(checkpoint, new_ckpt_path)
|
|
|
|
|
|
if __name__ == '__main__':
    # Execute the conversion only when run as a script, not when imported.
    main()
|