[Enhancement] Add merging script

pull/1687/head
fanqiNO1 2023-07-20 18:10:47 +08:00
parent 63343dc116
commit a12f43b92c
2 changed files with 64 additions and 1 deletions

View File

@ -15,7 +15,8 @@ model = dict(
img_size=384,
patch_size=16,
drop_rate=0.1,
init_cfg=dict(type='Pretrained', checkpoint='')),
init_cfg=dict(type='Pretrained', checkpoint='',
prefix='backbone')),
alpha=16,
rank=16,
drop_rate=0.1,

View File

@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from pathlib import Path
import torch
def merge_lora_weight(base_weight, lora_weight):
    """Merge base weight and lora weight.

    Every entry of ``lora_weight['state_dict']`` is copied unchanged. Every
    base entry is renamed before being copied:

    * ``'module'`` is inserted as the second dotted component of the key,
      e.g. ``backbone.patch_embed.projection.weight`` becomes
      ``backbone.module.patch_embed.projection.weight``.
    * If the renamed key's parent path equals the path of some LoRA entry
      (the LoRA key without its last two components), ``'original_layer'``
      is additionally inserted before the final component, e.g.
      ``backbone.module.layers.0.attn.qkv.original_layer.weight``.

    Args:
        base_weight (dict): weight dict from LoRAModel.module. Either a flat
            ``{name: value}`` dict or a checkpoint dict that stores that
            dict under the ``'state_dict'`` key.
        lora_weight (dict): weight dict from LoRAModel. Must contain a
            ``'state_dict'`` entry; all other top-level entries are carried
            over to the result unchanged.

    Returns:
        dict: Merged weight. ``'state_dict'`` holds the combined entries,
        followed by the remaining top-level entries of ``lora_weight``.
    """
    # A base checkpoint may wrap the actual weights in 'state_dict'. Without
    # unwrapping, the wrapper keys themselves would be renamed and merged
    # (yielding entries such as 'state_dict.module') and no weight would be.
    nested = base_weight.get('state_dict')
    base_state = nested if isinstance(nested, dict) else base_weight

    merged = {}
    # Maps the path of a LoRA entry to the path under which the weights of
    # the corresponding base layer have to be stored.
    wrapped_paths = {}
    for lora_key, lora_value in lora_weight['state_dict'].items():
        # backbone.module.layers.0.attn.qkv.lora_up.weight
        #   -> backbone.module.layers.0.attn.qkv
        layer_path = '.'.join(lora_key.split('.')[:-2])
        wrapped_paths[layer_path] = f'{layer_path}.original_layer'
        merged[lora_key] = lora_value

    for base_key, base_value in base_state.items():
        # backbone.patch_embed.projection.weight
        #   -> backbone.module.patch_embed.projection.weight
        parts = base_key.split('.')
        parts.insert(1, 'module')
        parent_path = '.'.join(parts[:-1])
        # Redirect to '<path>.original_layer' only when a LoRA entry exists
        # for that path; otherwise keep the plain renamed path.
        parent_path = wrapped_paths.get(parent_path, parent_path)
        # Written after the LoRA entries, so a base entry wins on a clash.
        merged['.'.join([parent_path, parts[-1]])] = base_value

    result = {'state_dict': merged}
    # Keep everything else stored next to the LoRA weights as it is.
    result.update((key, value) for key, value in lora_weight.items()
                  if key != 'state_dict')
    return result
# Command line entry point: load the base and the LoRA checkpoint, merge them
# with ``merge_lora_weight`` and save the result to the destination path.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Merge LoRA weight')
    parser.add_argument('base_src', help='path of the base model checkpoint')
    parser.add_argument('lora_src', help='path of the LoRA model checkpoint')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    # Validate the destination before the checkpoints are loaded.
    dst = Path(args.dst)
    if dst.suffix != '.pth':
        print('The path should contain the name of the pth format file.')
        # ``exit`` is only injected by the ``site`` module; raising
        # SystemExit gives the same exit status without relying on it.
        raise SystemExit(1)
    dst.parent.mkdir(parents=True, exist_ok=True)

    # NOTE(review): torch.load unpickles the files, so only checkpoints from
    # a trusted source should be passed in.
    base_model = torch.load(args.base_src, map_location='cpu')
    lora_model = torch.load(args.lora_src, map_location='cpu')

    merged_model = merge_lora_weight(base_model, lora_model)
    torch.save(merged_model, args.dst)