[Feature] Support LLaVA 1.5 (#1853)

* Support LLaVA 1.5

* Fix lint
Ma Zerun 2023-12-22 16:28:20 +08:00 committed by GitHub
parent ed5924b6fe
commit 3022f9af7b
8 changed files with 338 additions and 175 deletions


@@ -16,46 +16,28 @@ Instruction tuning large language models (LLMs) using machine-generated instruct
<!-- [TABS-BEGIN] -->
**Prepare the checkpoint**
According to the license of LLaMA, we cannot provide the merged checkpoint directly. Please use the
script below to download and merge the checkpoint.
```shell
python tools/model_converters/llava-delta2mmpre.py huggyllama/llama-7b liuhaotian/LLaVA-Lightning-7B-delta-v1-1 ./LLaVA-Lightning-7B-delta-v1-1.pth
```
**Use the model**
```python
import torch
from mmpretrain import get_model, inference_model
model = get_model('llava-7b-v1_caption', pretrained='MERGED_CHECKPOINT_PATH', device='cuda')
out = inference_model(model, 'demo/cat-dog.png')
out = inference_model('llava-7b-v1_caption', 'demo/cat-dog.png', device='cuda')
print(out)
# {'pred_caption': 'In the image, there are two cats sitting on a blanket.'}
```
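The 1.5 checkpoints added here are hosted, so they can be used by model name without any merging step. A minimal sketch; passing the question as the second positional argument follows mmpretrain's generic `inference_model` dispatch for VQA, and the `pred_answer` key is an assumption:

```python
from mmpretrain import inference_model

# Caption with the hosted LLaVA 1.5 weights (downloaded automatically).
out = inference_model('llava-7b-v1.5_caption', 'demo/cat-dog.png', device='cuda')
print(out['pred_caption'])

# VQA with the same weights; the free-form question is passed after the image
# (assumed to follow the VQA inferencer convention).
out = inference_model('llava-7b-v1.5_vqa', 'demo/cat-dog.png',
                      'How many dogs are in the picture?', device='cuda')
print(out)  # expected to contain a 'pred_answer' field
```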
**Test Command**
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
Test:
```shell
python tools/test.py configs/llava/llava-7b-v1_caption.py MERGED_CHECKPOINT_PATH
```
<!-- [TABS-END] -->
## Models and results
### Image Caption on COCO
| Model | Params (M) | BLEU-4 | CIDER | Config | Download |
| :-------------------- | :--------: | :------: | :------: | :------------------------------: | :--------------------: |
| `llava-7b-v1_caption` | 7045.82 | Upcoming | Upcoming | [config](llava-7b-v1_caption.py) | See the above tutorial |
| Model | Params (M) | Config | Download |
| :---------------------- | :--------: | :--------------------------------: | :-------------------------------------------------------------------------------------------------------------: |
| `llava-7b-v1_caption` | 7045.82 | [config](llava-7b-v1_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth) |
| `llava-7b-v1.5_caption` | 7062.90 | [config](llava-7b-v1.5_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |
| `llava-7b-v1.5_vqa` | 7062.90 | [config](llava-7b-v1.5_vqa.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |
## Citation


@@ -0,0 +1,76 @@
_base_ = '../_base_/default_runtime.py'
meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." # noqa: E501
image_size = 336
prompt_tmpl = f'''{meta_prompt} User: <image>
Describe the image in detail. ASSISTANT:'''
# model settings
model = dict(
type='Llava',
tokenizer=dict(
type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
vision_encoder=dict(
type='VisionTransformer',
arch='l',
patch_size=14,
img_size=image_size,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
final_norm=False,
out_type='raw',
pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
),
mm_hidden_size=1024,
use_im_patch=False,
use_im_start_end=False,
mm_proj_depth=2,
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
),
task='caption',
prompt_tmpl=prompt_tmpl,
generation_cfg=dict(num_beams=3, max_new_tokens=50, length_penalty=-1.0),
)
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(image_size, image_size),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id']),
]
test_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_val.json',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
test_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)
# schedule settings
test_cfg = dict()
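For a quick sanity check of this config without the full COCO test run, a minimal sketch; it assumes `get_model` accepts a config path (as elsewhere in mmpretrain), and the checkpoint URL is the one released with this change:

```python
from mmpretrain import get_model, inference_model

# Build the v1.5 caption model from this config plus the released weights.
ckpt = ('https://download.openmmlab.com/mmclassification/v1/llava/'
        'llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth')
model = get_model('configs/llava/llava-7b-v1.5_caption.py',
                  pretrained=ckpt, device='cuda')
print(inference_model(model, 'demo/cat-dog.png'))
```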


@@ -0,0 +1,76 @@
_base_ = '../_base_/default_runtime.py'
meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." # noqa: E501
image_size = 336
prompt_tmpl = f'''{meta_prompt} User: <image>
{{question}} ASSISTANT:'''
# model settings
model = dict(
type='Llava',
tokenizer=dict(
type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
vision_encoder=dict(
type='VisionTransformer',
arch='l',
patch_size=14,
img_size=image_size,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
final_norm=False,
out_type='raw',
pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
),
mm_hidden_size=1024,
use_im_patch=False,
use_im_start_end=False,
mm_proj_depth=2,
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
),
task='vqa',
prompt_tmpl=prompt_tmpl,
generation_cfg=dict(max_new_tokens=100),
)
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(image_size, image_size),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id', 'question']),
]
test_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_val.json',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
test_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)
# schedule settings
test_cfg = dict()
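The only differences from the caption config are the `{question}` placeholder in `prompt_tmpl`, `task='vqa'`, and the extra `question` meta key. A small illustration of how the template is filled per sample (the actual formatting happens in `llava.py` later in this diff; the question string is just an example):

```python
# Illustration: how the VQA prompt template above is filled for one sample.
meta_prompt = ("A chat between a curious human and an artificial intelligence "
               "assistant. The assistant gives helpful, detailed, and polite "
               "answers to the human's questions.")
prompt_tmpl = f'''{meta_prompt} User: <image>
{{question}} ASSISTANT:'''

# '{question}' comes from the data sample; '<image>' is replaced by a
# placeholder token id (-200) during tokenization (see llava.py below).
prompt = prompt_tmpl.format(question='How many dogs are in the picture?')
print(prompt)
```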


@@ -1,16 +1,9 @@
_base_ = '../_base_/default_runtime.py'
meta_prompt = 'You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.' # noqa: E501
im_patch_token = '<im_patch>'
patch_size = 14
image_size = 224
num_patches = (image_size // patch_size)**2
caption_prompt = ' '.join([
meta_prompt,
'User: a photo of\n',
im_patch_token * num_patches,
'ASSISTANT:',
])
prompt_tmpl = f'''{meta_prompt} User: <im_start><image><im_end>
Describe the image in detail. ASSISTANT:'''
# model settings
model = dict(
@@ -22,6 +15,7 @@ model = dict(
type='VisionTransformer',
arch='l',
patch_size=14,
img_size=image_size,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
@@ -32,15 +26,16 @@ model = dict(
'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
),
mm_hidden_size=1024,
use_im_start_end=False,
use_mm_proj=True,
use_im_patch=False,
use_im_start_end=True,
mm_proj_depth=1,
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
),
task='caption',
prompt_tmpl=caption_prompt,
generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
prompt_tmpl=prompt_tmpl,
generation_cfg=dict(max_new_tokens=50),
)
# data settings


@@ -21,5 +21,31 @@ Models:
Metrics:
BLEU-4: null
CIDER: null
Weights: null
Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth
Config: configs/llava/llava-7b-v1_caption.py
- Name: llava-7b-v1.5_caption
Metadata:
FLOPs: null
Parameters: 7062900736
In Collection: LLaVA
Results:
- Task: Image Caption
Dataset: COCO
Metrics:
BLEU-4: null
CIDER: null
Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
Config: configs/llava/llava-7b-v1.5_caption.py
- Name: llava-7b-v1.5_vqa
Metadata:
FLOPs: null
Parameters: 7062900736
In Collection: LLaVA
Results:
- Task: Visual Question Answering
Dataset: COCO
Metrics:
BLEU-4: null
CIDER: null
Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
Config: configs/llava/llava-7b-v1.5_vqa.py
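Once the metafile entries above are in place, the new variants become discoverable by name through `mmpretrain.list_models`; a quick check (the listed names are the aliases registered above):

```python
from mmpretrain import list_models

# The metafile entries above register these aliases.
print(list_models('llava'))
# Expected to include: 'llava-7b-v1_caption', 'llava-7b-v1.5_caption',
# 'llava-7b-v1.5_vqa'
```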


@@ -24,8 +24,8 @@ class Llava(BaseModel):
use_im_start_end (bool): Whether to use the im_start and im_end tokens
mm_vision_select_layer (int): The index from vision encoder output.
Defaults to -1.
use_mm_proj (bool): Whether to enable multi-modal projection.
Defaults to True.
mm_proj_depth (int): The number of linear layers for multi-modal
projection. Defaults to 1.
load_lang_pretrained (bool): Whether to load the pretrained model of
language encoder. Defaults to False.
generation_cfg (dict): The extra generation config, accept the keyword
@@ -51,9 +51,10 @@ class Llava(BaseModel):
mm_hidden_size: int,
prompt_tmpl: str,
task: str = 'caption',
use_im_patch: bool = True,
use_im_start_end: bool = False,
mm_vision_select_layer: int = -1,
use_mm_proj: bool = True,
mm_proj_depth: int = 1,
generation_cfg: dict = dict(),
load_lang_pretrained: bool = False,
data_preprocessor: Optional[dict] = None,
@@ -75,7 +76,9 @@ class Llava(BaseModel):
# init tokenizer
self.tokenizer = TOKENIZER.build(tokenizer)
# add Llava special tokens to the tokenizer
self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True)
if use_im_patch:
self.tokenizer.add_tokens([self.im_patch_token],
special_tokens=True)
if use_im_start_end:
self.tokenizer.add_tokens([self.im_start_token, self.im_end_token],
special_tokens=True)
@@ -108,14 +111,12 @@ class Llava(BaseModel):
vision_encoder=vision_encoder,
lang_encoder=lang_encoder,
mm_hidden_size=mm_hidden_size,
use_mm_proj=use_mm_proj,
mm_proj_depth=mm_proj_depth,
use_im_start_end=use_im_start_end,
im_start_token=self.tokenizer.convert_tokens_to_ids(
self.im_start_token),
im_end_token=self.tokenizer.convert_tokens_to_ids(
self.im_end_token),
im_patch_token=self.tokenizer.convert_tokens_to_ids(
self.im_patch_token),
mm_vision_select_layer=mm_vision_select_layer)
self.generation_cfg = generation_cfg
@@ -207,16 +208,24 @@ class Llava(BaseModel):
Returns:
List[DataSample]: Return list of data samples.
"""
prompts = []
tokens = []
for sample in data_samples:
final_prompt = self.prompt_tmpl.format(**sample.to_dict())
prompts.append(final_prompt)
prompt = self.prompt_tmpl.format(**sample.to_dict())
input_ids = []
while '<image>' in prompt:
prefix, _, prompt = prompt.partition('<image>')
input_ids.extend(
self.tokenizer(prefix, add_special_tokens=False).input_ids)
input_ids.append(-200)
if prompt:
input_ids.extend(
self.tokenizer(prompt, add_special_tokens=False).input_ids)
tokens.append(dict(input_ids=input_ids))
self.tokenizer.padding_side = 'left'
input_text = self.tokenizer(
prompts,
input_text = self.tokenizer.pad(
tokens,
padding='longest',
truncation=True,
return_tensors='pt',
max_length=2000,
).to(device)
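The rewritten preprocessing no longer expands hundreds of `<im_patch>` tokens into the prompt; instead it splits the prompt on `<image>` and inserts the sentinel index `-200` (`im_token_index`), which the language model later swaps for the projected image features. A standalone sketch of that splitting, using the tokenizer named in the v1.5 configs:

```python
from transformers import AutoTokenizer

IMAGE_TOKEN_INDEX = -200  # matches `im_token_index` in modules.py
tokenizer = AutoTokenizer.from_pretrained('liuhaotian/llava-v1.5-7b')

prompt = 'USER: <image>\nDescribe the image in detail. ASSISTANT:'
input_ids = []
while '<image>' in prompt:
    # Tokenize the text before the marker, then drop in the sentinel id.
    prefix, _, prompt = prompt.partition('<image>')
    input_ids.extend(tokenizer(prefix, add_special_tokens=False).input_ids)
    input_ids.append(IMAGE_TOKEN_INDEX)
if prompt:
    input_ids.extend(tokenizer(prompt, add_special_tokens=False).input_ids)

print(input_ids)  # [...text ids..., -200, ...text ids...]
```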


@@ -31,10 +31,10 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
lang_encoder,
mm_hidden_size,
use_im_start_end=True,
use_mm_proj=True,
mm_proj_depth=1,
im_start_token: Optional[int] = None,
im_end_token: Optional[int] = None,
im_patch_token: Optional[int] = None,
im_token_index: int = -200,
mm_vision_select_layer: int = -1):
super().__init__(lang_encoder.config)
self.vision_tower = vision_encoder
@@ -43,16 +43,26 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
self.use_im_start_end = use_im_start_end
self.im_start_token = im_start_token
self.im_end_token = im_end_token
self.im_patch_token = im_patch_token
self.mm_hidden_size = mm_hidden_size
self.mm_vision_select_layer = mm_vision_select_layer
self.im_token_index = im_token_index
self.lang_hidden_size = lang_encoder.config.hidden_size
if use_mm_proj and not hasattr(lang_encoder.model, 'mm_projector'):
if mm_proj_depth == 1:
# Llava V1
mm_projector = nn.Linear(self.mm_hidden_size,
self.lang_hidden_size)
self.lang_encoder.model.add_module('mm_projector', mm_projector)
elif not use_mm_proj:
elif mm_proj_depth > 1:
# Llava V1.5
modules = [nn.Linear(self.mm_hidden_size, self.lang_hidden_size)]
for _ in range(1, mm_proj_depth):
modules.append(nn.GELU())
modules.append(
nn.Linear(self.lang_hidden_size, self.lang_hidden_size))
mm_projector = nn.Sequential(*modules)
self.lang_encoder.model.add_module('mm_projector', mm_projector)
elif mm_proj_depth == 0:
self.lang_encoder.model.add_module('mm_projector', nn.Identity())
self.post_init()
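The boolean `use_mm_proj` is replaced by `mm_proj_depth`, which selects between the v1 and v1.5 projector shapes. A standalone illustration of the module built above (`build_mm_projector` is a hypothetical helper for this sketch; 1024 and 4096 are the CLIP ViT-L/14 and LLaMA-7B hidden sizes, and 576 is the number of 336px/14 patches):

```python
import torch
from torch import nn


def build_mm_projector(mm_hidden_size, lang_hidden_size, depth):
    """Hypothetical helper mirroring the mm_proj_depth branches above."""
    if depth == 1:      # LLaVA v1: a single linear projection
        return nn.Linear(mm_hidden_size, lang_hidden_size)
    if depth > 1:       # LLaVA v1.5: linear -> GELU -> linear MLP (depth=2)
        modules = [nn.Linear(mm_hidden_size, lang_hidden_size)]
        for _ in range(1, depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(lang_hidden_size, lang_hidden_size))
        return nn.Sequential(*modules)
    return nn.Identity()  # depth == 0: pass image features through unchanged


proj = build_mm_projector(1024, 4096, depth=2)
print(proj(torch.randn(1, 576, 1024)).shape)  # torch.Size([1, 576, 4096])
```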
@@ -80,16 +90,12 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
return_dict
if return_dict is not None else self.config.use_return_dict)
# decoder outputs consists of
# (dec_features, layer_state, dec_hidden, dec_attn)
if inputs_embeds is None:
inputs_embeds = self.lang_encoder.model.embed_tokens(input_ids)
inputs_embeds = self.forward_vision_tower(input_ids, inputs_embeds,
images)
(input_ids, attention_mask, past_key_values, inputs_embeds,
labels) = self.forward_vision_tower(input_ids, attention_mask,
past_key_values, labels, images)
return self.lang_encoder(
input_ids=None,
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -127,106 +133,93 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
def forward_vision_tower(
self,
input_ids: torch.LongTensor,
inputs_embeds: torch.FloatTensor,
images: Union[torch.FloatTensor, list, None] = None,
attention_mask: torch.LongTensor,
past_key_values: torch.FloatTensor,
labels: torch.LongTensor,
images: Union[torch.FloatTensor, None] = None,
):
if self.use_im_start_end:
assert self.im_start_token is not None
assert self.im_end_token is not None
if images is not None:
assert self.im_patch_token is not None
if self.vision_tower is None or images is None or (
input_ids.shape[1] == 1 and not self.training):
return inputs_embeds
if self.vision_tower is None or images is None or input_ids.shape[
1] == 1:
if (past_key_values is not None and self.vision_tower is not None
and images is not None and input_ids.shape[1] == 1):
attention_mask = torch.ones(
(attention_mask.shape[0],
past_key_values[-1][-1].shape[-2] + 1),
dtype=attention_mask.dtype,
device=attention_mask.device)
return input_ids, attention_mask, past_key_values, None, labels
with torch.no_grad():
if isinstance(images, (list, tuple)):
# variable length images
image_features = []
for image in images:
feats = self.vision_tower(image.unsqueeze(0))
image_feature = feats[self.mm_vision_select_layer][:, 1:]
image_features.append(image_feature)
else:
feats = self.vision_tower(images)
image_features = feats[self.mm_vision_select_layer][:, 1:]
# TODO: support variable number of images (single now)
feats = self.vision_tower(images)
image_features = feats[-1][:, 1:]
mm_projector = self.lang_encoder.model.mm_projector
if isinstance(images, (list, tuple)):
image_features = [
mm_projector(image_feature)[0]
for image_feature in image_features
]
else:
image_features = mm_projector(image_features)
dummy_image_features = torch.zeros(
256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features = mm_projector(dummy_image_features)
image_features = self.lang_encoder.model.mm_projector(image_features)
new_input_embeds = []
cur_image_idx = 0
for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
if (cur_input_ids != self.im_patch_token).all():
# multimodal LLM, but the current sample is not multimodal
cur_input_embeds = cur_input_embeds + (
0. * dummy_image_features).sum()
new_input_embeds.append(cur_input_embeds)
cur_image_idx += 1
continue
if self.use_im_start_end:
cur_image_features = image_features[cur_image_idx]
num_patches = cur_image_features.shape[0]
if (cur_input_ids == self.im_start_token).sum() != (
cur_input_ids == self.im_end_token).sum():
raise ValueError('The number of image start tokens and '
'image end tokens should be the same.')
image_start_tokens = torch.where(
cur_input_ids == self.im_start_token)[0]
for image_start_token_pos in image_start_tokens:
cur_image_features = image_features[cur_image_idx].to(
device=cur_input_embeds.device)
num_patches = cur_image_features.shape[0]
if cur_input_ids[image_start_token_pos + num_patches +
1] != self.im_end_token:
raise ValueError('The image end token should follow '
'the image start token.')
cur_new_input_embeds = torch.cat(
(cur_input_embeds[:image_start_token_pos + 1],
cur_image_features,
cur_input_embeds[image_start_token_pos + num_patches +
1:]),
dim=0)
cur_image_idx += 1
new_input_embeds.append(cur_new_input_embeds)
else:
cur_image_features = image_features[cur_image_idx]
num_patches = cur_image_features.shape[0]
if (cur_input_ids == self.im_patch_token).sum() != num_patches:
print(f'Debug: num_patches: {num_patches}')
raise ValueError(
'The number of image patch tokens should '
'be the same as the number of image patches.')
masked_indices = torch.where(
cur_input_ids == self.im_patch_token)[0]
mask_index_start = masked_indices[0]
if (masked_indices != torch.arange(
mask_index_start,
mask_index_start + num_patches,
device=masked_indices.device,
dtype=masked_indices.dtype)).any():
raise ValueError(
'The image patch tokens should be consecutive.')
cur_new_input_embeds = torch.cat(
(cur_input_embeds[:mask_index_start], cur_image_features,
cur_input_embeds[mask_index_start + num_patches:]),
dim=0)
new_input_embeds.append(cur_new_input_embeds)
cur_image_idx += 1
inputs_embeds = torch.stack(new_input_embeds, dim=0)
new_labels = [] if labels is not None else None
new_attn_mask = [] if attention_mask is not None else None
for batch_idx, cur_input_ids in enumerate(input_ids):
cur_img = image_features[batch_idx]
return inputs_embeds
if (cur_input_ids != self.im_token_index).all():
# multimodal LLM, but the current sample is not multimodal
new_input_embeds.append(self.embed_tokens(cur_input_ids))
if labels is not None:
new_labels.append(labels[batch_idx])
if attention_mask is not None:
new_attn_mask.append(attention_mask[batch_idx])
continue
img_idx = torch.where(cur_input_ids == self.im_token_index)[0][0]
if self.use_im_start_end:
cur_new_input_embeds = torch.cat(
[
self.embed_tokens(cur_input_ids[:img_idx - 1]),
self.embed_tokens(cur_input_ids[img_idx - 1:img_idx]),
cur_img,
self.embed_tokens(
cur_input_ids[img_idx + 1:img_idx + 2]),
self.embed_tokens(cur_input_ids[img_idx + 2:]),
],
dim=0,
)
else:
cur_new_input_embeds = torch.cat(
[
self.embed_tokens(cur_input_ids[:img_idx]),
cur_img,
self.embed_tokens(cur_input_ids[img_idx + 1:]),
],
dim=0,
)
new_input_embeds.append(cur_new_input_embeds)
if labels is not None:
cur_new_labels = torch.cat([
labels[batch_idx, :img_idx],
labels.new_full((cur_img.size(0), ), -100),
labels[batch_idx, img_idx + 1:],
],
dim=0)
new_labels.append(cur_new_labels)
if attention_mask is not None:
cur_attn_mask = torch.cat([
attention_mask[batch_idx, :img_idx],
attention_mask.new_full((cur_img.size(0), ), True),
attention_mask[batch_idx, img_idx + 1:],
],
dim=0)
new_attn_mask.append(cur_attn_mask)
inputs_embeds = torch.stack(new_input_embeds, dim=0)
if labels is not None:
labels = torch.stack(new_labels, dim=0)
if attention_mask is not None:
attention_mask = torch.stack(new_attn_mask, dim=0)
return None, attention_mask, past_key_values, inputs_embeds, labels
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
@@ -236,3 +229,6 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
past_state.index_select(0, beam_idx)
for past_state in layer_past), )
return reordered_past
def embed_tokens(self, input_ids):
return self.lang_encoder.model.embed_tokens(input_ids)


@@ -9,23 +9,21 @@ from huggingface_hub import snapshot_download
from transformers.modeling_utils import load_state_dict
prog_description = """\
Merge Llava delta weights and original weights,
and save as MMPreTrain checkpoint.
Convert Llava weights to an MMPreTrain checkpoint, optionally merging delta weights.
"""
def parse_args():
parser = argparse.ArgumentParser(description=prog_description)
parser.add_argument(
'src_path', type=str, help='The original checkpoint dir')
parser.add_argument(
'delta_path', type=str, help='The delta checkpoint dir')
parser.add_argument('dst_path', type=str, help='The saved checkpoint path')
parser.add_argument('src', type=str, help='The original checkpoint dir')
parser.add_argument('dst', type=str, help='The saved checkpoint path')
parser.add_argument('--delta', type=str, help='The delta checkpoint dir')
args = parser.parse_args()
return args
def load_checkpoint(path: Path):
path = Path(path)
if path.is_file():
return torch.load(path)
@@ -41,19 +39,23 @@ def load_checkpoint(path: Path):
def main():
args = parse_args()
if Path(args.src_path).exists():
src_path = Path(args.src_path)
if Path(args.src).exists():
src_path = args.src
else:
src_path = Path(snapshot_download(args.src_path))
src_path = snapshot_download(
args.src, allow_patterns='pytorch_model*.bin')
src_state_dict = load_checkpoint(src_path)
if Path(args.delta_path).exists():
delta_path = Path(args.delta_path)
if args.delta is None:
delta_state_dict = {}
elif Path(args.delta).exists():
delta_state_dict = load_checkpoint(args.delta)
else:
delta_path = Path(snapshot_download(args.delta_path))
delta_state_dict = load_checkpoint(delta_path)
delta_path = snapshot_download(
args.delta, allow_patterns='pytorch_model*.bin')
delta_state_dict = load_checkpoint(delta_path)
merged_state_dict = OrderedDict()
new_state_dict = OrderedDict()
for k, v in src_state_dict.items():
if k in delta_state_dict:
delta_v = delta_state_dict.pop(k)
@@ -63,12 +65,13 @@ def main():
v = delta_v
else:
v += delta_v
merged_state_dict['model.lang_encoder.' + k] = v
if 'rotary_emb.inv_freq' not in k:
new_state_dict['model.lang_encoder.' + k] = v
for k, v in delta_state_dict.items():
merged_state_dict['model.lang_encoder.' + k] = v
new_state_dict['model.lang_encoder.' + k] = v
torch.save(merged_state_dict, args.dst_path)
torch.save(new_state_dict, args.dst)
print('Done!!')
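After conversion, the resulting file is used as a regular pretrained checkpoint; a minimal sketch following the v1 README example above (the output path is whatever was passed as `dst`):

```python
from mmpretrain import get_model, inference_model

# Hypothetical path: the `dst` argument given to the converter above.
model = get_model('llava-7b-v1_caption',
                  pretrained='./LLaVA-Lightning-7B-delta-v1-1.pth',
                  device='cuda')
print(inference_model(model, 'demo/cat-dog.png'))
```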