# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
from itertools import chain
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from transformers.modeling_utils import load_state_dict

prog_description = """\
Merge the Llava delta weights with the original weights, and save the
converted checkpoint.
"""


def parse_args():
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument('src', type=str, help='The original checkpoint dir')
    parser.add_argument('dst', type=str, help='The saved checkpoint path')
    parser.add_argument('--delta', type=str, help='The delta checkpoint dir')
    args = parser.parse_args()
    return args


def load_checkpoint(path: Path):
    path = Path(path)
    if path.is_file():
        # A single checkpoint file can be loaded directly.
        return torch.load(path)

    # Otherwise treat the path as a directory and merge every shard
    # (.bin, .pth or .safetensors) into a single state dict.
    state_dict = OrderedDict()
    for ckpt in chain(
            path.rglob('*.bin'), path.rglob('*.pth'),
            path.rglob('*.safetensors')):
        state_dict.update(load_state_dict(str(ckpt)))

    return state_dict
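
# For example, load_checkpoint('ckpt_dir') returns one merged OrderedDict even
# if the directory contains several 'pytorch_model-*.bin' shards.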


def main():
    args = parse_args()

    # Resolve the source checkpoint: use the local path if it exists,
    # otherwise download it from the Hugging Face Hub.
    if Path(args.src).exists():
        src_path = args.src
    else:
        src_path = snapshot_download(
            args.src, allow_patterns='pytorch_model*.bin')
    src_state_dict = load_checkpoint(src_path)

    # Resolve the optional delta checkpoint in the same way.
    if args.delta is None:
        delta_state_dict = {}
    elif Path(args.delta).exists():
        delta_state_dict = load_checkpoint(args.delta)
    else:
        delta_path = snapshot_download(
            args.delta, allow_patterns='pytorch_model*.bin')
        delta_state_dict = load_checkpoint(delta_path)

    # Merge the delta weights into the source weights.
    new_state_dict = OrderedDict()
    for k, v in src_state_dict.items():
        if k in delta_state_dict:
            delta_v = delta_state_dict.pop(k)
            if k in ['model.embed_tokens.weight', 'lm_head.weight']:
                # The delta vocabulary may be larger than the original one,
                # so only add the source weights to the overlapping region.
                h, w = v.shape[:2]
                delta_v[:h, :w] += v
                v = delta_v
            else:
                v += delta_v
        # Skip the cached rotary embedding frequencies and prefix the
        # remaining keys for the target language encoder.
        if 'rotary_emb.inv_freq' not in k:
            new_state_dict['model.lang_encoder.' + k] = v

    # Keep the keys that only exist in the delta checkpoint.
    for k, v in delta_state_dict.items():
        new_state_dict['model.lang_encoder.' + k] = v

    torch.save(new_state_dict, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()