[Feature] Support LLaVA (#1652)

Ma Zerun 2023-06-17 16:05:52 +08:00 committed by GitHub
parent e69bace03f
commit bfd49b0d52
12 changed files with 757 additions and 1 deletions

View File

@@ -256,6 +256,7 @@ Results and models are available in the [model zoo](https://mmpretrain.readthedo
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
<li><a href="configs/llava">LLaVA (arxiv'2023)</a></li>
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
</ul>
</td>

View File

@@ -252,6 +252,7 @@ mim install -e ".[multimodal]"
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
<li><a href="configs/llava">LLaVA (arxiv'2023)</a></li>
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
</ul>
</td>

View File

@@ -0,0 +1,68 @@
# LLaVA
> [Visual Instruction Tuning](https://arxiv.org/abs/2304.08485)
<!-- [ALGORITHM] -->
## Abstract
Instruction tuning large language models (LLMs) using machine-generated instruction-following data has improved zero-shot capabilities on new tasks, but the idea is less explored in the multimodal field. In this paper, we present the first attempt to use language-only GPT-4 to generate multimodal language-image instruction-following data. By instruction tuning on such generated data, we introduce LLaVA: Large Language and Vision Assistant, an end-to-end trained large multimodal model that connects a vision encoder and LLM for general-purpose visual and language understanding. Our early experiments show that LLaVA demonstrates impressive multimodal chat abilities, sometimes exhibiting the behaviors of multimodal GPT-4 on unseen images/instructions, and yields an 85.1% relative score compared with GPT-4 on a synthetic multimodal instruction-following dataset. When fine-tuned on Science QA, the synergy of LLaVA and GPT-4 achieves a new state-of-the-art accuracy of 92.53%. We make GPT-4 generated visual instruction tuning data, our model and code base publicly available.
<div align=center>
<img src="https://github-production-user-asset-6210df.s3.amazonaws.com/26739999/246466979-c2f41b71-1de3-4da8-b20a-eaebe722c339.png" width="80%"/>
</div>
## How to use it?
<!-- [TABS-BEGIN] -->
**Prepare the checkpoint**
According to the license of LLaMA, we cannot provide the merged checkpoint directly. Please use the script below to download the weights and merge them into a single checkpoint.
```bash
python tools/model_converters/llava-delta2mmpre.py huggyllama/llama-7b liuhaotian/LLaVA-Lightning-7B-delta-v1-1 ./LLaVA-Lightning-7B-delta-v1-1.pth
```
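The merge loads both sets of weights into memory, so make sure enough RAM is available. If you want to sanity-check the merged file afterwards, a minimal sketch along the lines below should work (assuming the output path used above); the `model.lang_encoder.` key prefix comes from the converter script in this PR.
```python
import torch

# Load on CPU; the file holds the merged language-model weights only.
state_dict = torch.load('./LLaVA-Lightning-7B-delta-v1-1.pth', map_location='cpu')
print(len(state_dict), 'tensors')
print(list(state_dict)[:3])  # keys are expected to start with 'model.lang_encoder.'
```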
**Use the model**
```python
import torch
from mmpretrain import get_model, inference_model
model = get_model('llava-7b-v1_caption', pretrained='MERGED_CHECKPOINT_PATH', device='cuda')
out = inference_model(model, 'demo/cat-dog.png')
print(out)
```
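Alternatively, the generic MMPreTrain inferencer interface should work with this config as well. The snippet below is a sketch; the exact result fields may vary between versions.
```python
from mmpretrain import ImageCaptionInferencer

# 'pretrained' points to the merged checkpoint produced above.
inferencer = ImageCaptionInferencer(
    'llava-7b-v1_caption',
    pretrained='MERGED_CHECKPOINT_PATH',
    device='cuda')
result = inferencer('demo/cat-dog.png')[0]
print(result.get('pred_caption'))
```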
**Test Command**
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
Test:
```shell
python tools/test.py configs/llava/llava-7b-v1_caption.py MERGED_CHECKPOINT_PATH
```
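Multi-GPU testing with the standard MMPreTrain launcher should also work, for example with 8 GPUs:
```shell
bash tools/dist_test.sh configs/llava/llava-7b-v1_caption.py MERGED_CHECKPOINT_PATH 8
```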
<!-- [TABS-END] -->
## Models and results
### Image Caption on COCO
| Model | Params (M) | BLEU-4 | CIDEr | Config | Download |
| :-------------------- | :--------: | :------: | :------: | :------------------------------: | :--------------------: |
| `llava-7b-v1_caption` | 7045.82 | Upcoming | Upcoming | [config](llava-7b-v1_caption.py) | See the above tutorial |
## Citation
```bibtex
@misc{liu2023llava,
title={Visual Instruction Tuning},
author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae},
publisher={arXiv:2304.08485},
year={2023},
}
```

View File

@@ -0,0 +1,83 @@
_base_ = '../_base_/default_runtime.py'
meta_prompt = 'You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.' # noqa: E501
im_patch_token = '<im_patch>'
patch_size = 14
image_size = 224
num_patches = (image_size // patch_size)**2
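# With a 224x224 input and a patch size of 14, num_patches is
# (224 // 14) ** 2 = 256, so the caption prompt below contains 256
# '<im_patch>' placeholders that the model later replaces with projected
# CLIP-ViT patch features.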
caption_prompt = ' '.join([
meta_prompt,
'User: a photo of\n',
im_patch_token * num_patches,
'ASSISTANT:',
])
# model settings
model = dict(
type='Llava',
tokenizer=dict(
type='AutoTokenizer',
name_or_path='liuhaotian/LLaVA-Lightning-7B-delta-v1-1'),
vision_encoder=dict(
type='VisionTransformer',
arch='l',
patch_size=14,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
final_norm=False,
out_type='raw',
pretrained=(
'https://download.openmmlab.com/mmclassification/v0/clip/'
'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
),
mm_hidden_size=1024,
use_im_start_end=False,
use_mm_proj=True,
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
),
task='caption',
prompt_tmpl=caption_prompt,
generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
)
# data settings
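# The mean/std below are the OpenAI CLIP normalization constants scaled by 255.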
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(image_size, image_size),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id']),
]
test_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_val.json',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
test_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)
# schedule settings
test_cfg = dict()

View File

@@ -0,0 +1,25 @@
Collections:
- Name: LLaVA
Metadata:
Architecture:
- LLaMA
- CLIP
Paper:
Title: Visual Instruction Tuning
URL: https://arxiv.org/abs/2304.08485
README: configs/llava/README.md
Models:
- Name: llava-7b-v1_caption
Metadata:
FLOPs: null
Parameters: 7045816320
In Collection: LLaVA
Results:
- Task: Image Caption
Dataset: COCO
Metrics:
BLEU-4: null
CIDER: null
Weights: null
Config: configs/llava/llava-7b-v1_caption.py

View File

@@ -144,6 +144,7 @@ Multi-Modality Algorithms
Flamingo
OFA
MiniGPT4
Llava
Otter
.. module:: mmpretrain.models.backbones

View File

@@ -6,6 +6,7 @@ if WITH_MULTIMODAL:
from .blip2 import * # noqa: F401,F403
from .chinese_clip import * # noqa: F401, F403
from .flamingo import * # noqa: F401, F403
from .llava import * # noqa: F401, F403
from .minigpt4 import * # noqa: F401, F403
from .ofa import * # noqa: F401, F403
from .otter import * # noqa: F401, F403
@@ -16,5 +17,5 @@ else:
register_multimodal_placeholder([
'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption',
'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo',
'OFA', 'ChineseCLIP', 'MiniGPT4', 'Otter'
'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter'
], MODELS)

View File

@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .llava import Llava
from .modules import LlavaLlamaForCausalLM
__all__ = ['Llava', 'LlavaLlamaForCausalLM']

View File

@@ -0,0 +1,256 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import List, Optional
import torch
from mmengine.model import BaseModel
from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample
from ...utils import no_load_hf_pretrained_model
from .modules import LlavaLlamaForCausalLM
@MODELS.register_module()
class Llava(BaseModel):
"""The LLaVA model for multiple tasks.
Args:
vision_encoder (dict): The config of the vision encoder.
lang_encoder (dict): The config of the language encoder.
tokenizer (dict): The tokenizer to encode the text.
mm_hidden_size (int): The feature dimension of the vision encoder
output, i.e. the input size of the multi-modal projector.
prompt_tmpl (str): Prompt template for inference.
task (str): The task to perform prediction.
use_im_start_end (bool): Whether to use the im_start and im_end tokens.
Defaults to False.
mm_vision_select_layer (int): The index of the vision encoder output
used as image features. Defaults to -1.
use_mm_proj (bool): Whether to enable the multi-modal projection.
Defaults to True.
load_lang_pretrained (bool): Whether to load the pretrained weights of
the language encoder. Defaults to False.
generation_cfg (dict): The extra generation config, accepts the keyword
arguments of :class:`~transformers.GenerationConfig`.
Defaults to an empty dict.
data_preprocessor (Optional[dict]): The config for preprocessing input
data. If None or no specified type, it will use
"MultiModalDataPreprocessor" as type.
See :class:`MultiModalDataPreprocessor` for more details.
Defaults to None.
init_cfg (dict, optional): The initialization config. Defaults to None.
"""
support_tasks = {'caption', 'vqa'}
im_patch_token = '<im_patch>'
im_start_token = '<im_start>'
im_end_token = '<im_end>'
def __init__(self,
vision_encoder: dict,
lang_encoder: dict,
tokenizer: dict,
mm_hidden_size: int,
prompt_tmpl: str,
task: str = 'caption',
use_im_start_end: bool = False,
mm_vision_select_layer: int = -1,
use_mm_proj: bool = True,
generation_cfg: dict = dict(),
load_lang_pretrained: bool = False,
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[dict] = None):
if data_preprocessor is None:
data_preprocessor = {}
if isinstance(data_preprocessor, dict):
data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
data_preprocessor = MODELS.build(data_preprocessor)
super().__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
if task not in self.support_tasks:
raise ValueError(f'Unsupported task {task}, please select '
f'the task from {self.support_tasks}.')
self.task = task
# init tokenizer
self.tokenizer = TOKENIZER.build(tokenizer)
# add Llava special tokens to the tokenizer
self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True)
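# Registering '<im_patch>' gives the prompt placeholders a dedicated token
# id, so their embeddings can later be swapped for projected image features.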
if use_im_start_end:
self.tokenizer.add_tokens([self.im_start_token, self.im_end_token],
special_tokens=True)
# Template to format the prompt input
self.prompt_tmpl = prompt_tmpl
# init vision encoder related modules
vision_encoder_weight = vision_encoder.pop('pretrained', None)
vision_encoder = MODELS.build(vision_encoder)
if vision_encoder_weight is not None:
from mmengine.runner.checkpoint import load_checkpoint
load_checkpoint(
vision_encoder,
vision_encoder_weight,
map_location='cpu',
revise_keys=[(r'^backbone\.', '')],
)
# init language encoder related modules
if load_lang_pretrained:
lang_encoder = MODELS.build(lang_encoder)
else:
with no_load_hf_pretrained_model():
lang_encoder = MODELS.build(lang_encoder)
lang_encoder.resize_token_embeddings(len(self.tokenizer))
self.model = LlavaLlamaForCausalLM(
vision_encoder=vision_encoder,
lang_encoder=lang_encoder,
mm_hidden_size=mm_hidden_size,
use_mm_proj=use_mm_proj,
use_im_start_end=use_im_start_end,
im_start_token=self.tokenizer.convert_tokens_to_ids(
self.im_start_token),
im_end_token=self.tokenizer.convert_tokens_to_ids(
self.im_end_token),
im_patch_token=self.tokenizer.convert_tokens_to_ids(
self.im_patch_token),
mm_vision_select_layer=mm_vision_select_layer)
self.generation_cfg = generation_cfg
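# The post hook (available in recent PyTorch versions) silences missing-key
# warnings for the vision tower, whose weights are loaded above from the
# CLIP checkpoint rather than from the merged LLaVA checkpoint.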
if hasattr(self, 'register_load_state_dict_post_hook'):
self.register_load_state_dict_post_hook(self._load_ckpt_hook)
def forward(
self,
images: torch.Tensor,
data_samples: Optional[List[DataSample]] = None,
mode: str = 'loss',
):
"""The unified entry for a forward process in both training and test.
The method should accept only one mode "loss":
- "loss": Forward and return a dict of losses according to the given
inputs and data samples.
Note that this method doesn't handle neither back propagation nor
optimizer updating, which are done in the :meth:`train_step`.
Args:
images (torch.Tensor): The input image tensor with different ndim
according to the inputs.
data_samples (List[DataSample], optional): The annotation
data of every samples. It's required if ``mode="loss"``.
Defaults to None.
mode (str): Return what kind of value. Defaults to 'loss'.
Returns:
The return type depends on ``mode``.
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'loss':
return self.loss(images, data_samples)
elif mode == 'predict':
return self.predict(images, data_samples)
else:
raise RuntimeError(f'Invalid mode "{mode}".')
def predict(self,
images: torch.Tensor,
data_samples: Optional[List[DataSample]] = None,
**generation_cfg):
"""Predict generation results from a batch of inputs.
Args:
images (torch.Tensor): For zero-shot, the input images tensor is
with shape (B, C, H, W), for few-shot, which is
(B, T_img, C, H, W) in general. Images in the same chunk
are collated along T_img. Video data is not supported yet.
data_samples (List[DataSample], optional): The annotation
data of every samples. Defaults to None.
**generation_cfg: Other keyword arguments accepted by the
``generate`` method of :attr:`lang_encoder`.
Returns:
List[DataSample]: Return list of data samples.
"""
# generation_cfg in prediction should be dominant
generation_cfg = {**self.generation_cfg, **generation_cfg}
input_text = self.preprocess_text(data_samples, device=images.device)
outputs = self.model.generate(
input_text.input_ids,
attention_mask=input_text.attention_mask,
eos_token_id=self.tokenizer.eos_token_id,
images=images,
**generation_cfg)
# remove prefix
outputs = outputs[:, len(input_text.input_ids[0]):]
return self.post_process(outputs, data_samples)
def preprocess_text(self, data_samples: List[DataSample],
device: torch.device) -> List[DataSample]:
"""Preprocess text in advance before fed into language model.
Args:
data_samples (List[DataSample]): The annotation
data of every samples. Defaults to None.
device (torch.device): Device for text to put on.
Returns:
List[DataSample]: Return list of data samples.
"""
prompts = []
for sample in data_samples:
final_prompt = self.prompt_tmpl.format(**sample.to_dict())
prompts.append(final_prompt)
self.tokenizer.padding_side = 'left'
input_text = self.tokenizer(
prompts,
padding='longest',
truncation=True,
return_tensors='pt',
max_length=2000,
).to(device)
return input_text
def post_process(
self, outputs: torch.Tensor,
data_samples: Optional[List[DataSample]]) -> List[DataSample]:
"""Perform post process for outputs for different task.
Args:
outputs (torch.Tensor): The generated outputs.
data_samples (List[DataSample], optional): The annotation
data of every samples.
Returns:
List[DataSample]: Return list of data samples.
"""
outputs = self.tokenizer.batch_decode(
outputs, skip_special_tokens=True)
if data_samples is None:
data_samples = [DataSample() for _ in range(len(outputs))]
for output, data_sample in zip(outputs, data_samples):
# remove text pattern
if self.task == 'caption':
data_sample.pred_caption = output
elif self.task == 'vqa':
data_sample.pred_answer = output
return data_samples
@staticmethod
def _load_ckpt_hook(module, incompatible_keys):
"""Avoid warning missing keys except lang_encoder keys."""
for key in list(incompatible_keys.missing_keys):
if re.match('model.vision_tower', key):
incompatible_keys.missing_keys.remove(key)

View File

@@ -0,0 +1,238 @@
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import torch
import torch.nn as nn
from transformers import PreTrainedModel
DEFAULT_IMAGE_TOKEN = '<image>'
DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
DEFAULT_IM_START_TOKEN = '<im_start>'
DEFAULT_IM_END_TOKEN = '<im_end>'
class LlavaLlamaForCausalLM(PreTrainedModel):
def __init__(self,
vision_encoder,
lang_encoder,
mm_hidden_size,
use_im_start_end=True,
use_mm_proj=True,
im_start_token: Optional[int] = None,
im_end_token: Optional[int] = None,
im_patch_token: Optional[int] = None,
mm_vision_select_layer: int = -1):
super().__init__(lang_encoder.config)
self.vision_tower = vision_encoder
self.lang_encoder = lang_encoder
self.use_im_start_end = use_im_start_end
self.im_start_token = im_start_token
self.im_end_token = im_end_token
self.im_patch_token = im_patch_token
self.mm_hidden_size = mm_hidden_size
self.mm_vision_select_layer = mm_vision_select_layer
self.lang_hidden_size = lang_encoder.config.hidden_size
if use_mm_proj and not hasattr(lang_encoder.model, 'mm_projector'):
mm_projector = nn.Linear(self.mm_hidden_size,
self.lang_hidden_size)
self.lang_encoder.model.add_module('mm_projector', mm_projector)
elif not use_mm_proj:
self.lang_encoder.model.add_module('mm_projector', nn.Identity())
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
):
output_attentions = (
output_attentions if output_attentions is not None else
self.config.output_attentions)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = (
return_dict
if return_dict is not None else self.config.use_return_dict)
# decoder outputs consists of
# (dec_features, layer_state, dec_hidden, dec_attn)
if inputs_embeds is None:
inputs_embeds = self.lang_encoder.model.embed_tokens(input_ids)
inputs_embeds = self.forward_vision_tower(input_ids, inputs_embeds,
images)
return self.lang_encoder(
input_ids=None,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
)
def prepare_inputs_for_generation(self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use
# them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}
model_inputs.update({
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
'images': kwargs.get('images', None),
})
return model_inputs
def forward_vision_tower(
self,
input_ids: torch.LongTensor,
inputs_embeds: torch.FloatTensor,
images: Union[torch.FloatTensor, list, None] = None,
):
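# Encode the images with the vision tower, project the selected layer's
# patch features into the language embedding space, and splice them into
# the token embeddings at the image placeholder positions.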
if self.use_im_start_end:
assert self.im_start_token is not None
assert self.im_end_token is not None
if images is not None:
assert self.im_patch_token is not None
if self.vision_tower is None or images is None or (
input_ids.shape[1] == 1 and not self.training):
return inputs_embeds
with torch.no_grad():
if isinstance(images, (list, tuple)):
# variable length images
image_features = []
for image in images:
feats = self.vision_tower(image.unsqueeze(0))
image_feature = feats[self.mm_vision_select_layer][:, 1:]
image_features.append(image_feature)
else:
feats = self.vision_tower(images)
image_features = feats[self.mm_vision_select_layer][:, 1:]
mm_projector = self.lang_encoder.model.mm_projector
if isinstance(images, (list, tuple)):
image_features = [
mm_projector(image_feature)[0]
for image_feature in image_features
]
else:
image_features = mm_projector(image_features)
dummy_image_features = torch.zeros(
256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features = mm_projector(dummy_image_features)
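# Rebuild each sequence, replacing its placeholder token embeddings with the
# corresponding image features; the zero-scaled dummy features keep the
# projector in the graph for samples without image tokens.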
new_input_embeds = []
cur_image_idx = 0
for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
if (cur_input_ids != self.im_patch_token).all():
# multimodal LLM, but the current sample is not multimodal
cur_input_embeds = cur_input_embeds + (
0. * dummy_image_features).sum()
new_input_embeds.append(cur_input_embeds)
cur_image_idx += 1
continue
if self.use_im_start_end:
cur_image_features = image_features[cur_image_idx]
num_patches = cur_image_features.shape[0]
if (cur_input_ids == self.im_start_token).sum() != (
cur_input_ids == self.im_end_token).sum():
raise ValueError('The number of image start tokens and '
'image end tokens should be the same.')
image_start_tokens = torch.where(
cur_input_ids == self.im_start_token)[0]
for image_start_token_pos in image_start_tokens:
cur_image_features = image_features[cur_image_idx].to(
device=cur_input_embeds.device)
num_patches = cur_image_features.shape[0]
if cur_input_ids[image_start_token_pos + num_patches +
1] != self.im_end_token:
raise ValueError('The image end token should follow '
'the image start token.')
cur_new_input_embeds = torch.cat(
(cur_input_embeds[:image_start_token_pos + 1],
cur_image_features,
cur_input_embeds[image_start_token_pos + num_patches +
1:]),
dim=0)
cur_image_idx += 1
new_input_embeds.append(cur_new_input_embeds)
else:
cur_image_features = image_features[cur_image_idx]
num_patches = cur_image_features.shape[0]
if (cur_input_ids == self.im_patch_token).sum() != num_patches:
raise ValueError(
'The number of image patch tokens should be the same '
f'as the number of image patches ({num_patches}).')
masked_indices = torch.where(
cur_input_ids == self.im_patch_token)[0]
mask_index_start = masked_indices[0]
if (masked_indices != torch.arange(
mask_index_start,
mask_index_start + num_patches,
device=masked_indices.device,
dtype=masked_indices.dtype)).any():
raise ValueError(
'The image patch tokens should be consecutive.')
cur_new_input_embeds = torch.cat(
(cur_input_embeds[:mask_index_start], cur_image_features,
cur_input_embeds[mask_index_start + num_patches:]),
dim=0)
new_input_embeds.append(cur_new_input_embeds)
cur_image_idx += 1
inputs_embeds = torch.stack(new_input_embeds, dim=0)
return inputs_embeds
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(
past_state.index_select(0, beam_idx)
for past_state in layer_past), )
return reordered_past

View File

@@ -79,4 +79,5 @@ Import:
- configs/itpn/metafile.yml
- configs/hivit/metafile.yml
- configs/minigpt4/metafile.yml
- configs/llava/metafile.yml
- configs/otter/metafile.yml

View File

@@ -0,0 +1,76 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
from itertools import chain
from pathlib import Path
import torch
from huggingface_hub import snapshot_download
from transformers.modeling_utils import load_state_dict
prog_description = """\
Merge Llava delta weights and original weights,
and save as MMPreTrain checkpoint.
"""
def parse_args():
parser = argparse.ArgumentParser(description=prog_description)
parser.add_argument(
'src_path', type=str, help='The original checkpoint dir')
parser.add_argument(
'delta_path', type=str, help='The delta checkpoint dir')
parser.add_argument('dst_path', type=str, help='The saved checkpoint path')
args = parser.parse_args()
return args
def load_checkpoint(path: Path):
if path.is_file():
return torch.load(path)
state_dict = OrderedDict()
for ckpt in chain(
path.rglob('*.bin'), path.rglob('*.pth'),
path.rglob('*.safetensors')):
state_dict.update(load_state_dict(str(ckpt)))
return state_dict
def main():
args = parse_args()
if Path(args.src_path).exists():
src_path = Path(args.src_path)
else:
src_path = Path(snapshot_download(args.src_path))
src_state_dict = load_checkpoint(src_path)
if Path(args.delta_path).exists():
delta_path = Path(args.delta_path)
else:
delta_path = Path(snapshot_download(args.delta_path))
delta_state_dict = load_checkpoint(delta_path)
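# LLaVA is released as delta weights. Most tensors are merged by simple
# addition; the token embedding and LM head of the delta grew to hold extra
# image tokens, so the original weight is added into their top-left block
# instead of being added elementwise.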
merged_state_dict = OrderedDict()
for k, v in src_state_dict.items():
if k in delta_state_dict:
delta_v = delta_state_dict.pop(k)
if k in ['model.embed_tokens.weight', 'lm_head.weight']:
h, w = v.shape[:2]
delta_v[:h, :w] += v
v = delta_v
else:
v += delta_v
merged_state_dict['model.lang_encoder.' + k] = v
for k, v in delta_state_dict.items():
merged_state_dict['model.lang_encoder.' + k] = v
torch.save(merged_state_dict, args.dst_path)
print('Done!!')
if __name__ == '__main__':
main()