mmpretrain/tools/model_converters/llava-delta2mmpre.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
from itertools import chain
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from transformers.modeling_utils import load_state_dict

prog_description = """\
Merge LLaVA delta weights into the original LLM weights and convert the
result into an MMPreTrain checkpoint.
"""


def parse_args():
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument('src', type=str, help='The original checkpoint dir')
    parser.add_argument('dst', type=str, help='The saved checkpoint path')
    parser.add_argument('--delta', type=str, help='The delta checkpoint dir')
    args = parser.parse_args()
    return args


def load_checkpoint(path: Path):
    path = Path(path)
    if path.is_file():
        # A single checkpoint file can be loaded directly.
        return torch.load(path)

    # Otherwise, merge all shards found under the checkpoint directory.
    state_dict = OrderedDict()
    for ckpt in chain(
            path.rglob('*.bin'), path.rglob('*.pth'),
            path.rglob('*.safetensors')):
        state_dict.update(load_state_dict(str(ckpt)))

    return state_dict


def main():
    args = parse_args()

    # The source can be a local path or a Hugging Face Hub repo id.
    if Path(args.src).exists():
        src_path = args.src
    else:
        src_path = snapshot_download(
            args.src, allow_patterns='pytorch_model*.bin')
    src_state_dict = load_checkpoint(src_path)

    if args.delta is None:
        delta_state_dict = {}
    elif Path(args.delta).exists():
        delta_state_dict = load_checkpoint(args.delta)
    else:
        delta_path = snapshot_download(
            args.delta, allow_patterns='pytorch_model*.bin')
        delta_state_dict = load_checkpoint(delta_path)

    new_state_dict = OrderedDict()
    for k, v in src_state_dict.items():
        if k in delta_state_dict:
            delta_v = delta_state_dict.pop(k)
            if k in ['model.embed_tokens.weight', 'lm_head.weight']:
                # The delta checkpoint may use an extended vocabulary, so add
                # the original weights only to the overlapping region.
                h, w = v.shape[:2]
                delta_v[:h, :w] += v
                v = delta_v
            else:
                v += delta_v
        if 'rotary_emb.inv_freq' not in k:
            new_state_dict['model.lang_encoder.' + k] = v

    # Parameters that exist only in the delta checkpoint are kept as-is.
    for k, v in delta_state_dict.items():
        new_state_dict['model.lang_encoder.' + k] = v

    torch.save(new_state_dict, args.dst)
    print('Done!!')


if __name__ == '__main__':
    main()