mirror of https://github.com/open-mmlab/mmyolo.git
55 lines
1.7 KiB
Python
55 lines
1.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
import os.path as osp
|
|
from pathlib import Path
|
|
|
|
from mmengine.runner import CheckpointLoader, save_checkpoint
|
|
from mmengine.utils import mkdir_or_exist
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the KD-to-student converter.

    Args:
        argv (list[str], optional): Argument strings to parse. When left as
            ``None`` argparse falls back to ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace: Namespace with ``checkpoint`` (str),
        ``out_path`` (str or None) and ``inplace`` (bool).
    """
    parser = argparse.ArgumentParser(
        description='Convert KD checkpoint to student-only checkpoint')
    parser.add_argument('checkpoint', help='input checkpoint filename')
    parser.add_argument('--out-path', help='save checkpoint path')
    parser.add_argument(
        '--inplace', action='store_true', help='replace origin ckpt')
    args = parser.parse_args(argv)
    return args
|
|
|
|
|
|
def main():
    """Convert a KD checkpoint into a student-only checkpoint and save it.

    Loads the checkpoint named on the command line, keeps only the
    ``state_dict`` entries whose key starts with ``'architecture.'`` (with
    that leading prefix removed) together with the original ``meta``, and
    writes the result either over the input file (``--inplace``) or to
    ``<name>_student.pth`` in ``--out-path`` (defaulting to the input
    checkpoint's directory).

    Raises:
        ValueError: If no ``state_dict`` key starts with ``'architecture.'``,
            so nothing would be left to save.
        AssertionError: If ``--inplace`` is given but the checkpoint path
            does not exist on the local filesystem.
    """
    args = parse_args()
    checkpoint = CheckpointLoader.load_checkpoint(
        args.checkpoint, map_location='cpu')

    prefix = 'architecture.'
    # Strip only the *leading* prefix: ``str.replace`` would also delete any
    # later occurrence of 'architecture.' inside the key and corrupt it.
    new_state_dict = {
        key[len(prefix):]: value
        for key, value in checkpoint['state_dict'].items()
        if key.startswith(prefix)
    }
    if not new_state_dict:
        # Refuse to write (or, with --inplace, overwrite the input with) a
        # checkpoint that contains no weights at all.
        raise ValueError(
            f'no keys starting with "{prefix}" found in the state_dict of '
            f'{args.checkpoint}')

    student_checkpoint = dict(
        meta=checkpoint['meta'], state_dict=new_state_dict)

    if args.inplace:
        # NOTE(review): presumably guards against a checkpoint argument that
        # is not a local file path -- confirm against CheckpointLoader.
        assert osp.exists(args.checkpoint), \
            f'can not find the checkpoint path: {args.checkpoint}'
        save_checkpoint(student_checkpoint, args.checkpoint)
    else:
        ckpt_path = Path(args.checkpoint)
        ckpt_name = ckpt_path.stem
        if args.out_path:
            ckpt_dir = Path(args.out_path)
        else:
            ckpt_dir = ckpt_path.parent
        mkdir_or_exist(ckpt_dir)
        new_ckpt_path = osp.join(ckpt_dir, f'{ckpt_name}_student.pth')
        save_checkpoint(student_checkpoint, new_ckpt_path)
|
|
|
|
|
|
# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()
|