mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
* [Feature] Add RepeatAugSampler * initial commit * spts inference done * merge repeat_aug (bug in multi-node?) * fix inference * train done * rm readme * Revert "merge repeat_aug (bug in multi-node?)" This reverts commit 393506a97cbe6d75ad1f28611ea10eba6b8fa4b3. * Revert "[Feature] Add RepeatAugSampler" This reverts commit 2089b02b4844157670033766f257b5d1bca452ce. * remove utils * readme & conversion script * update readme * fix * optimize * rename cfg & del compose * fix * fix
46 lines
1.3 KiB
Python
46 lines
1.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
|
|
import torch
|
|
|
|
# Key-prefix renames applied to the source checkpoint: a parameter name that
# starts with one of the keys below has that prefix swapped for the mapped
# value. ``adapt`` stops at the first matching entry, so order is significant.
prefix_mapping = {
    'backbone.0.body': 'backbone',
    'input_proj': 'encoder.input_proj',
    'transformer': 'decoder',
    'vocab_embed.layers.': 'decoder.vocab_embed.layer-',
}
|
|
|
|
|
|
def adapt(model_path, save_path):
    """Rename the parameter keys of a checkpoint and save the result.

    The checkpoint at ``model_path`` is expected to be a dict holding its
    weights under the ``'model'`` key. Every weight name that starts with a
    key of ``prefix_mapping`` has that leading prefix replaced by the mapped
    value; all other names are kept as they are. The converted weights are
    stored under ``'state_dict'``, the ``'model'`` entry is removed, and any
    other top-level entries of the checkpoint are written back untouched.

    Args:
        model_path (str): Path to the source checkpoint.
        save_path (str): Path where the converted checkpoint is written.

    Raises:
        KeyError: If the loaded checkpoint has no ``'model'`` entry.
    """
    # Load onto CPU so the conversion also works on machines that lack the
    # device the checkpoint was originally saved from.
    checkpoint = torch.load(model_path, map_location='cpu')
    source_weights = checkpoint['model']
    # Start from a copy so the mapping type and the untouched entries of the
    # source weights are preserved.
    converted = source_weights.copy()

    for key, value in source_weights.items():
        for old_prefix, new_prefix in prefix_mapping.items():
            if key.startswith(old_prefix):
                # Swap only the leading prefix; ``str.replace`` would also
                # rewrite later occurrences of the same text in the key.
                new_key = new_prefix + key[len(old_prefix):]
                # Delete before inserting so that a rename mapping a key to
                # itself cannot drop the entry.
                del converted[key]
                converted[new_key] = value
                break  # first matching prefix wins

    checkpoint['state_dict'] = converted
    del checkpoint['model']
    torch.save(checkpoint, save_path)
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line arguments of the conversion script.

    Args:
        argv (list[str], optional): Argument strings to parse instead of the
            process command line. Defaults to None, in which case argparse
            reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Namespace with the two positional arguments
        ``model_path`` and ``out_path``.
    """
    parser = argparse.ArgumentParser(
        description='Adapt the pretrained checkpoints from SPTS official '
        'implementation.')
    parser.add_argument(
        'model_path', type=str, help='Path to the source model')
    parser.add_argument(
        'out_path', type=str, help='Path to the converted model')
    return parser.parse_args(argv)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: read both paths from the command line and convert
    # the source checkpoint into the output file.
    cli_args = parse_args()
    adapt(cli_args.model_path, cli_args.out_path)
|