mmocr/projects/SPTS/tools/ckpt_adapter.py
Tong Gao 2d743cfa19
[Model] SPTS (#1696)
* [Feature] Add RepeatAugSampler

* initial commit

* spts inference done

* merge repeat_aug (bug in multi-node?)

* fix inference

* train done

* rm readme

* Revert "merge repeat_aug (bug in multi-node?)"

This reverts commit 393506a97cbe6d75ad1f28611ea10eba6b8fa4b3.

* Revert "[Feature] Add RepeatAugSampler"

This reverts commit 2089b02b4844157670033766f257b5d1bca452ce.

* remove utils

* readme & conversion script

* update readme

* fix

* optimize

* rename cfg & del compose

* fix

* fix
2023-02-01 11:58:03 +08:00

46 lines
1.3 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import torch
# Maps a parameter-name prefix found in the source checkpoint to the prefix
# that replaces it in the converted checkpoint. `adapt` tries the entries in
# insertion order and stops at the first one that matches a key.
prefix_mapping = {
    'backbone.0.body': 'backbone',
    'input_proj': 'encoder.input_proj',
    'transformer': 'decoder',
    'vocab_embed.layers.': 'decoder.vocab_embed.layer-',
}
def _rename_key(key, mapping):
    """Return ``key`` with its leading prefix rewritten through ``mapping``.

    The first entry of ``mapping`` (in insertion order) whose old prefix
    starts ``key`` is applied; a key matching no entry is returned unchanged.
    """
    for old_prefix, new_prefix in mapping.items():
        if key.startswith(old_prefix):
            # Swap only the leading prefix. The same text appearing later in
            # the key must be left untouched.
            return new_prefix + key[len(old_prefix):]
    return key


def adapt(model_path, save_path, mapping=None):
    """Convert a checkpoint by renaming the prefixes of its parameter keys.

    The weights stored under the ``'model'`` entry of the loaded checkpoint
    are re-keyed and written back under ``'state_dict'``; the ``'model'``
    entry is removed and every other top-level entry is kept as is.

    Args:
        model_path (str): Path to the source checkpoint.
        save_path (str): Path the converted checkpoint is written to.
        mapping (dict[str, str], optional): Old prefix -> new prefix table.
            Defaults to None, in which case the module-level
            ``prefix_mapping`` is used.

    Raises:
        KeyError: If the loaded checkpoint has no ``'model'`` entry.
    """
    if mapping is None:
        mapping = prefix_mapping

    # NOTE: torch.load unpickles the file, so only convert checkpoints that
    # come from a trusted source.
    # map_location='cpu' lets the conversion run on machines that do not have
    # the device the checkpoint was saved from.
    checkpoint = torch.load(model_path, map_location='cpu')

    source_state = checkpoint.pop('model')
    # Build the renamed dict in one pass instead of deleting from a copy
    # while iterating; this also keeps the original key order.
    checkpoint['state_dict'] = {
        _rename_key(key, mapping): value
        for key, value in source_state.items()
    }
    torch.save(checkpoint, save_path)
def parse_args(argv=None):
    """Parse the command-line arguments of the conversion script.

    Args:
        argv (list[str], optional): Argument strings to parse. Defaults to
            None, in which case argparse reads them from ``sys.argv``.

    Returns:
        argparse.Namespace: Namespace with the ``model_path`` and
        ``out_path`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Adapt the pretrained checkpoints from SPTS official '
        'implementation.')
    parser.add_argument(
        'model_path', type=str, help='Path to the source model')
    parser.add_argument(
        'out_path', type=str, help='Path to the converted model')
    return parser.parse_args(argv)
# Script entry point: read the two paths from the command line and convert.
if __name__ == '__main__':
    args = parse_args()
    adapt(args.model_path, args.out_path)