[Feature] Support LLaVA (#1652)
parent e69bace03f, commit bfd49b0d52
@ -256,6 +256,7 @@ Results and models are available in the [model zoo](https://mmpretrain.readthedo
|
|||
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
|
||||
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
|
||||
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
|
||||
<li><a href="configs/llava">LLaVA (arxiv'2023)</a></li>
|
||||
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
|
|
|
@ -252,6 +252,7 @@ mim install -e ".[multimodal]"
|
|||
<li><a href="configs/flamingo">Flamingo (NeurIPS'2022)</a></li>
|
||||
<li><a href="configs/chinese_clip">Chinese CLIP (arxiv'2022)</a></li>
|
||||
<li><a href="configs/minigpt4">MiniGPT-4 (arxiv'2023)</a></li>
|
||||
<li><a href="configs/llava">LLaVA (arxiv'2023)</a></li>
|
||||
<li><a href="configs/otter">Otter (arxiv'2023)</a></li>
|
||||
</ul>
|
||||
</td>
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
# LLaVA
|
||||
|
||||
> [Visual Instruction Tuning](https://arxiv.org/abs/2304.08485)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
Instruction tuning large language models (LLMs) using machine-generated instruction-following data has improved zero-shot capabilities on new tasks, but the idea is less explored in the multimodal field. In this paper, we present the first attempt to use language-only GPT-4 to generate multimodal language-image instruction-following data. By instruction tuning on such generated data, we introduce LLaVA: Large Language and Vision Assistant, an end-to-end trained large multimodal model that connects a vision encoder and LLM for general-purpose visual and language understanding. Our early experiments show that LLaVA demonstrates impressive multimodal chat abilities, sometimes exhibiting the behaviors of multimodal GPT-4 on unseen images/instructions, and yields an 85.1% relative score compared with GPT-4 on a synthetic multimodal instruction-following dataset. When fine-tuned on Science QA, the synergy of LLaVA and GPT-4 achieves a new state-of-the-art accuracy of 92.53%. We make GPT-4 generated visual instruction tuning data, our model and code base publicly available.
|
||||
|
||||
<div align=center>
|
||||
<img src="https://github-production-user-asset-6210df.s3.amazonaws.com/26739999/246466979-c2f41b71-1de3-4da8-b20a-eaebe722c339.png" width="80%"/>
|
||||
</div>
|
||||
|
||||
## How to use it?
|
||||
|
||||
<!-- [TABS-BEGIN] -->
|
||||
|
||||
**Prepare the checkpoint**
|
||||
|
||||
According to the license of LLaMA, we cannot provide the merged checkpoint directly. Please use the script below to download the weights and obtain the merged checkpoint.
|
||||
|
||||
```bash
|
||||
python tools/model_converters/llava-delta2mmpre.py huggyllama/llama-7b liuhaotian/LLaVA-Lightning-7B-delta-v1-1 ./LLaVA-Lightning-7B-delta-v1-1.pth
|
||||
```
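If you want to double-check the result, a quick look at the saved state dict is enough. Below is a minimal sketch (the `model.lang_encoder.` prefix is what the conversion script writes; adjust the path to wherever you saved the merged file):

```python
import torch

# Load the merged checkpoint on CPU and inspect its keys.
state_dict = torch.load('./LLaVA-Lightning-7B-delta-v1-1.pth', map_location='cpu')
print(f'{len(state_dict)} tensors')
print([k for k in state_dict if k.startswith('model.lang_encoder.')][:3])
```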
|
||||
|
||||
**Use the model**
|
||||
|
||||
```python
|
||||
import torch
|
||||
from mmpretrain import get_model, inference_model
|
||||
|
||||
model = get_model('llava-7b-v1_caption', pretrained='MERGED_CHECKPOINT_PATH', device='cuda')
|
||||
out = inference_model(model, 'demo/cat-dog.png')
|
||||
print(out)
|
||||
```
|
||||
|
||||
**Test Command**
|
||||
|
||||
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
|
||||
|
||||
Test:
|
||||
|
||||
```shell
|
||||
python tools/test.py configs/llava/llava-7b-v1_caption.py MERGED_CHECKPOINT_PATH
|
||||
```
|
||||
|
||||
<!-- [TABS-END] -->
|
||||
|
||||
## Models and results
|
||||
|
||||
### Image Caption on COCO
|
||||
|
||||
| Model | Params (M) | BLEU-4 | CIDER | Config | Download |
|
||||
| :-------------------- | :--------: | :------: | :------: | :------------------------------: | :--------------------: |
|
||||
| `llava-7b-v1_caption` | 7045.82 | Upcoming | Upcoming | [config](llava-7b-v1_caption.py) | See the above tutorial |
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{liu2023llava,
|
||||
title={Visual Instruction Tuning},
|
||||
author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae},
|
||||
publisher={arXiv:2304.08485},
|
||||
year={2023},
|
||||
}
|
||||
```
|
|
@ -0,0 +1,83 @@
|
|||
_base_ = '../_base_/default_runtime.py'
|
||||
|
||||
meta_prompt = 'You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.Follow the instructions carefully and explain your answers in detail.' # noqa: E501
|
||||
im_patch_token = '<im_patch>'
|
||||
patch_size = 14
|
||||
image_size = 224
|
||||
num_patches = (image_size // patch_size)**2
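# With 224x224 inputs and 14x14 patches, num_patches = (224 // 14) ** 2 = 256,
# so the caption prompt below repeats the <im_patch> placeholder 256 times.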
|
||||
caption_prompt = ' '.join([
|
||||
meta_prompt,
|
||||
'User: a photo of\n',
|
||||
im_patch_token * num_patches,
|
||||
'ASSISTANT:',
|
||||
])
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='Llava',
|
||||
tokenizer=dict(
|
||||
type='AutoTokenizer',
|
||||
name_or_path='liuhaotian/LLaVA-Lightning-7B-delta-v1-1'),
|
||||
vision_encoder=dict(
|
||||
type='VisionTransformer',
|
||||
arch='l',
|
||||
patch_size=14,
|
||||
pre_norm=True,
|
||||
norm_cfg=dict(type='LN', eps=1e-5),
|
||||
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
|
||||
final_norm=False,
|
||||
out_type='raw',
|
||||
pretrained=(
|
||||
'https://download.openmmlab.com/mmclassification/v0/clip/'
|
||||
'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
|
||||
),
|
||||
mm_hidden_size=1024,
|
||||
use_im_start_end=False,
|
||||
use_mm_proj=True,
|
||||
lang_encoder=dict(
|
||||
type='AutoModelForCausalLM',
|
||||
name_or_path='huggyllama/llama-7b',
|
||||
),
|
||||
task='caption',
|
||||
prompt_tmpl=caption_prompt,
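# Beam search with a negative length_penalty biases decoding towards short captions.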
|
||||
generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
|
||||
)
|
||||
|
||||
# data settings
|
||||
data_preprocessor = dict(
|
||||
type='MultiModalDataPreprocessor',
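# CLIP normalization statistics, scaled to the 0-255 pixel range.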
|
||||
mean=[122.770938, 116.7460125, 104.09373615],
|
||||
std=[68.5005327, 66.6321579, 70.32316305],
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
scale=(image_size, image_size),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='PackInputs', meta_keys=['image_id']),
|
||||
]
|
||||
|
||||
test_dataloader = dict(
|
||||
batch_size=8,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type='COCOCaption',
|
||||
data_root='data/coco',
|
||||
ann_file='annotations/coco_karpathy_val.json',
|
||||
pipeline=test_pipeline,
|
||||
),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
test_evaluator = dict(
|
||||
type='COCOCaption',
|
||||
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
|
||||
)
|
||||
|
||||
# schedule settings
|
||||
test_cfg = dict()
|
|
@ -0,0 +1,25 @@
|
|||
Collections:
|
||||
- Name: LLaVA
|
||||
Metadata:
|
||||
Architecture:
|
||||
- LLaMA
|
||||
- CLIP
|
||||
Paper:
|
||||
Title: Visual Instruction Tuning
|
||||
URL: https://arxiv.org/abs/2304.08485
|
||||
README: configs/llava/README.md
|
||||
|
||||
Models:
|
||||
- Name: llava-7b-v1_caption
|
||||
Metadata:
|
||||
FLOPs: null
|
||||
Parameters: 7045816320
|
||||
In Collection: LLaVA
|
||||
Results:
|
||||
- Task: Image Caption
|
||||
Dataset: COCO
|
||||
Metrics:
|
||||
BLEU-4: null
|
||||
CIDER: null
|
||||
Weights: null
|
||||
Config: configs/llava/llava-7b-v1_caption.py
|
|
@ -144,6 +144,7 @@ Multi-Modality Algorithms
|
|||
Flamingo
|
||||
OFA
|
||||
MiniGPT4
|
||||
Llava
|
||||
Otter
|
||||
|
||||
.. module:: mmpretrain.models.backbones
|
||||
|
|
|
@ -6,6 +6,7 @@ if WITH_MULTIMODAL:
|
|||
from .blip2 import * # noqa: F401,F403
|
||||
from .chinese_clip import * # noqa: F401, F403
|
||||
from .flamingo import * # noqa: F401, F403
|
||||
from .llava import * # noqa: F401, F403
|
||||
from .minigpt4 import * # noqa: F401, F403
|
||||
from .ofa import * # noqa: F401, F403
|
||||
from .otter import * # noqa: F401, F403
|
||||
|
@ -16,5 +17,5 @@ else:
|
|||
register_multimodal_placeholder([
|
||||
'Blip2Caption', 'Blip2Retrieval', 'Blip2VQA', 'BlipCaption',
|
||||
'BlipNLVR', 'BlipRetrieval', 'BlipGrounding', 'BlipVQA', 'Flamingo',
|
||||
'OFA', 'ChineseCLIP', 'MiniGPT4', 'Otter'
|
||||
'OFA', 'ChineseCLIP', 'MiniGPT4', 'Llava', 'Otter'
|
||||
], MODELS)
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .llava import Llava
|
||||
from .modules import LlavaLlamaForCausalLM
|
||||
|
||||
__all__ = ['Llava', 'LlavaLlamaForCausalLM']
|
|
@ -0,0 +1,256 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from mmengine.model import BaseModel
|
||||
|
||||
from mmpretrain.registry import MODELS, TOKENIZER
|
||||
from mmpretrain.structures import DataSample
|
||||
from ...utils import no_load_hf_pretrained_model
|
||||
from .modules import LlavaLlamaForCausalLM
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class Llava(BaseModel):
|
||||
"""The LLaVA model for multiple tasks.
|
||||
|
||||
Args:
|
||||
vision_encoder (dict): The config of the vision encoder.
|
||||
lang_encoder (dict): The config of the language encoder.
|
||||
tokenizer (dict): The tokenizer to encode the text.
mm_hidden_size (int): The input feature dimension of the multi-modal
    projector, i.e. the dimension of the vision encoder output.
|
||||
prompt_tmpl (str): Prompt template for inference.
|
||||
task (str): The task to perform prediction.
|
||||
use_im_start_end (bool): Whether to use the im_start and im_end tokens. Defaults to False.
|
||||
mm_vision_select_layer (int): The index of the vision encoder output used as the image features.
|
||||
Defaults to -1.
|
||||
use_mm_proj (bool): Whether to enable multi-modal projection.
|
||||
Defaults to True.
|
||||
load_lang_pretrained (bool): Whether to load the pretrained model of
|
||||
language encoder. Defaults to False.
|
||||
generation_cfg (dict): The extra generation config, accepting the keyword
arguments of :class:`~transformers.GenerationConfig`.
|
||||
Defaults to an empty dict.
|
||||
data_preprocessor (Optional[dict]): The config for preprocessing input
|
||||
data. If None or no specified type, it will use
"MultiModalDataPreprocessor" as type.
See :class:`MultiModalDataPreprocessor` for more details.
|
||||
Defaults to None.
|
||||
init_cfg (dict, optional): The initialization config. Defaults to None.
|
||||
"""
|
||||
|
||||
support_tasks = {'caption', 'vqa'}
|
||||
im_patch_token = '<im_patch>'
|
||||
im_start_token = '<im_start>'
|
||||
im_end_token = '<im_end>'
|
||||
|
||||
def __init__(self,
|
||||
vision_encoder: dict,
|
||||
lang_encoder: dict,
|
||||
tokenizer: dict,
|
||||
mm_hidden_size: int,
|
||||
prompt_tmpl: str,
|
||||
task: str = 'caption',
|
||||
use_im_start_end: bool = False,
|
||||
mm_vision_select_layer: int = -1,
|
||||
use_mm_proj: bool = True,
|
||||
generation_cfg: dict = dict(),
|
||||
load_lang_pretrained: bool = False,
|
||||
data_preprocessor: Optional[dict] = None,
|
||||
init_cfg: Optional[dict] = None):
|
||||
if data_preprocessor is None:
|
||||
data_preprocessor = {}
|
||||
if isinstance(data_preprocessor, dict):
|
||||
data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
|
||||
data_preprocessor = MODELS.build(data_preprocessor)
|
||||
|
||||
super().__init__(
|
||||
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
|
||||
|
||||
if task not in self.support_tasks:
|
||||
raise ValueError(f'Unsupported task {task}, please select '
|
||||
f'the task from {self.support_tasks}.')
|
||||
self.task = task
|
||||
|
||||
# init tokenizer
|
||||
self.tokenizer = TOKENIZER.build(tokenizer)
|
||||
# add Llava special tokens to the tokenizer
|
||||
self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True)
|
||||
if use_im_start_end:
|
||||
self.tokenizer.add_tokens([self.im_start_token, self.im_end_token],
|
||||
special_tokens=True)
|
||||
|
||||
# Template to format the prompt input
|
||||
self.prompt_tmpl = prompt_tmpl
|
||||
|
||||
# init vision encoder related modules
|
||||
vision_encoder_weight = vision_encoder.pop('pretrained', None)
|
||||
vision_encoder = MODELS.build(vision_encoder)
|
||||
if vision_encoder_weight is not None:
|
||||
from mmengine.runner.checkpoint import load_checkpoint
|
||||
load_checkpoint(
|
||||
vision_encoder,
|
||||
vision_encoder_weight,
|
||||
map_location='cpu',
|
||||
revise_keys=[(r'^backbone\.', '')],
|
||||
)
|
||||
|
||||
# init language encoder related modules
|
||||
if load_lang_pretrained:
|
||||
lang_encoder = MODELS.build(lang_encoder)
|
||||
else:
|
||||
with no_load_hf_pretrained_model():
|
||||
lang_encoder = MODELS.build(lang_encoder)
|
||||
lang_encoder.resize_token_embeddings(len(self.tokenizer))
|
||||
|
||||
self.model = LlavaLlamaForCausalLM(
|
||||
vision_encoder=vision_encoder,
|
||||
lang_encoder=lang_encoder,
|
||||
mm_hidden_size=mm_hidden_size,
|
||||
use_mm_proj=use_mm_proj,
|
||||
use_im_start_end=use_im_start_end,
|
||||
im_start_token=self.tokenizer.convert_tokens_to_ids(
|
||||
self.im_start_token),
|
||||
im_end_token=self.tokenizer.convert_tokens_to_ids(
|
||||
self.im_end_token),
|
||||
im_patch_token=self.tokenizer.convert_tokens_to_ids(
|
||||
self.im_patch_token),
|
||||
mm_vision_select_layer=mm_vision_select_layer)
|
||||
|
||||
self.generation_cfg = generation_cfg
|
||||
|
||||
if hasattr(self, 'register_load_state_dict_post_hook'):
|
||||
self.register_load_state_dict_post_hook(self._load_ckpt_hook)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images: torch.Tensor,
|
||||
data_samples: Optional[List[DataSample]] = None,
|
||||
mode: str = 'loss',
|
||||
):
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
The method should accept two modes: "loss" and "predict":
|
||||
|
||||
- "loss": Forward and return a dict of losses according to the given
|
||||
inputs and data samples.
- "predict": Forward and return the predictions, which are fully
processed to a list of :obj:`DataSample`.
|
||||
|
||||
Note that this method doesn't handle either back propagation or
|
||||
optimizer updating, which are done in the :meth:`train_step`.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): The input image tensor with different ndim
|
||||
according to the inputs.
|
||||
data_samples (List[DataSample], optional): The annotation
|
||||
data of every sample. It's required if ``mode="loss"``.
|
||||
Defaults to None.
|
||||
mode (str): Return what kind of value. Defaults to 'loss'.
|
||||
|
||||
Returns:
|
||||
The return type depends on ``mode``.
|
||||
- If ``mode="loss"``, return a dict of tensor.
- If ``mode="predict"``, return a list of :obj:`DataSample`.
|
||||
"""
|
||||
|
||||
if mode == 'loss':
|
||||
return self.loss(images, data_samples)
|
||||
elif mode == 'predict':
|
||||
return self.predict(images, data_samples)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}".')
|
||||
|
||||
def predict(self,
|
||||
images: torch.Tensor,
|
||||
data_samples: Optional[List[DataSample]] = None,
|
||||
**generation_cfg):
|
||||
"""Predict generation results from a batch of inputs.
|
||||
|
||||
Args:
|
||||
images (torch.Tensor): For zero-shot, the input images tensor has
shape (B, C, H, W); for few-shot, it is generally
(B, T_img, C, H, W), and images in the same chunk are collated
along T_img. Video data is not supported yet.
|
||||
data_samples (List[DataSample], optional): The annotation
|
||||
data of every sample. Defaults to None.
|
||||
**generation_cfg: Other keyword arguments accepted by the
|
||||
``generate`` method of :attr:`lang_encoder`.
|
||||
|
||||
Returns:
|
||||
List[DataSample]: Return list of data samples.
|
||||
"""
|
||||
# generation_cfg in prediction should be dominant
|
||||
generation_cfg = {**self.generation_cfg, **generation_cfg}
|
||||
|
||||
input_text = self.preprocess_text(data_samples, device=images.device)
|
||||
|
||||
outputs = self.model.generate(
|
||||
input_text.input_ids,
|
||||
attention_mask=input_text.attention_mask,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
images=images,
|
||||
**generation_cfg)
|
||||
|
||||
# remove prefix
|
||||
outputs = outputs[:, len(input_text.input_ids[0]):]
|
||||
|
||||
return self.post_process(outputs, data_samples)
|
||||
|
||||
def preprocess_text(self, data_samples: List[DataSample],
|
||||
device: torch.device) -> List[DataSample]:
"""Preprocess text before it is fed into the language model.
|
||||
|
||||
Args:
|
||||
data_samples (List[DataSample]): The annotation
|
||||
data of every sample. Defaults to None.
|
||||
device (torch.device): Device for text to put on.
|
||||
|
||||
Returns:
|
||||
List[DataSample]: Return list of data samples.
|
||||
"""
|
||||
prompts = []
|
||||
for sample in data_samples:
|
||||
final_prompt = self.prompt_tmpl.format(**sample.to_dict())
|
||||
prompts.append(final_prompt)
|
||||
|
||||
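# Pad on the left so batched generation continues from real prompt tokens.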
self.tokenizer.padding_side = 'left'
|
||||
input_text = self.tokenizer(
|
||||
prompts,
|
||||
padding='longest',
|
||||
truncation=True,
|
||||
return_tensors='pt',
|
||||
max_length=2000,
|
||||
).to(device)
|
||||
return input_text
|
||||
|
||||
def post_process(
|
||||
self, outputs: torch.Tensor,
|
||||
data_samples: Optional[List[DataSample]]) -> List[DataSample]:
"""Perform post-processing on the outputs for different tasks.
|
||||
|
||||
Args:
|
||||
outputs (torch.Tensor): The generated outputs.
|
||||
data_samples (List[DataSample], optional): The annotation
|
||||
data of every sample.
|
||||
|
||||
Returns:
|
||||
List[DataSample]: Return list of data samples.
|
||||
"""
|
||||
outputs = self.tokenizer.batch_decode(
|
||||
outputs, skip_special_tokens=True)
|
||||
|
||||
if data_samples is None:
|
||||
data_samples = [DataSample() for _ in range(len(outputs))]
|
||||
|
||||
for output, data_sample in zip(outputs, data_samples):
|
||||
# remove text pattern
|
||||
if self.task == 'caption':
|
||||
data_sample.pred_caption = output
|
||||
elif self.task == 'vqa':
|
||||
data_sample.pred_answer = output
|
||||
|
||||
return data_samples
|
||||
|
||||
@staticmethod
|
||||
def _load_ckpt_hook(module, incompatible_keys):
"""Ignore missing vision tower keys; they are loaded from the CLIP pretrained checkpoint instead of the merged checkpoint."""
|
||||
for key in list(incompatible_keys.missing_keys):
|
||||
if re.match('model.vision_tower', key):
|
||||
incompatible_keys.missing_keys.remove(key)
|
|
@ -0,0 +1,238 @@
|
|||
# Copyright 2023 Haotian Liu
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
DEFAULT_IMAGE_TOKEN = '<image>'
|
||||
DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
|
||||
DEFAULT_IM_START_TOKEN = '<im_start>'
|
||||
DEFAULT_IM_END_TOKEN = '<im_end>'
|
||||
|
||||
|
||||
class LlavaLlamaForCausalLM(PreTrainedModel):
|
||||
|
||||
def __init__(self,
|
||||
vision_encoder,
|
||||
lang_encoder,
|
||||
mm_hidden_size,
|
||||
use_im_start_end=True,
|
||||
use_mm_proj=True,
|
||||
im_start_token: Optional[int] = None,
|
||||
im_end_token: Optional[int] = None,
|
||||
im_patch_token: Optional[int] = None,
|
||||
mm_vision_select_layer: int = -1):
|
||||
super().__init__(lang_encoder.config)
|
||||
self.vision_tower = vision_encoder
|
||||
self.lang_encoder = lang_encoder
|
||||
|
||||
self.use_im_start_end = use_im_start_end
|
||||
self.im_start_token = im_start_token
|
||||
self.im_end_token = im_end_token
|
||||
self.im_patch_token = im_patch_token
|
||||
self.mm_hidden_size = mm_hidden_size
|
||||
self.mm_vision_select_layer = mm_vision_select_layer
|
||||
self.lang_hidden_size = lang_encoder.config.hidden_size
|
||||
|
||||
if use_mm_proj and not hasattr(lang_encoder.model, 'mm_projector'):
|
||||
mm_projector = nn.Linear(self.mm_hidden_size,
|
||||
self.lang_hidden_size)
|
||||
self.lang_encoder.model.add_module('mm_projector', mm_projector)
|
||||
elif not use_mm_proj:
|
||||
self.lang_encoder.model.add_module('mm_projector', nn.Identity())
|
||||
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
images: Optional[torch.FloatTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_attentions = (
|
||||
output_attentions if output_attentions is not None else
|
||||
self.config.output_attentions)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else
|
||||
self.config.output_hidden_states)
|
||||
return_dict = (
|
||||
return_dict
|
||||
if return_dict is not None else self.config.use_return_dict)
|
||||
|
||||
# decoder outputs consist of
|
||||
# (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.lang_encoder.model.embed_tokens(input_ids)
|
||||
|
||||
inputs_embeds = self.forward_vision_tower(input_ids, inputs_embeds,
|
||||
images)
|
||||
|
||||
return self.lang_encoder(
|
||||
input_ids=None,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
**kwargs):
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1:]
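# With a KV cache holding the prefix, only the newest token id is needed per step.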
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use
|
||||
# them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {'inputs_embeds': inputs_embeds}
|
||||
else:
|
||||
model_inputs = {'input_ids': input_ids}
|
||||
|
||||
model_inputs.update({
|
||||
'past_key_values': past_key_values,
|
||||
'use_cache': kwargs.get('use_cache'),
|
||||
'attention_mask': attention_mask,
|
||||
'images': kwargs.get('images', None),
|
||||
})
|
||||
return model_inputs
|
||||
|
||||
def forward_vision_tower(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
inputs_embeds: torch.FloatTensor,
|
||||
images: Union[torch.FloatTensor, list, None] = None,
|
||||
):
|
||||
if self.use_im_start_end:
|
||||
assert self.im_start_token is not None
|
||||
assert self.im_end_token is not None
|
||||
if images is not None:
|
||||
assert self.im_patch_token is not None
|
||||
|
||||
if self.vision_tower is None or images is None or (
|
||||
input_ids.shape[1] == 1 and not self.training):
|
||||
return inputs_embeds
|
||||
|
||||
with torch.no_grad():
|
||||
if isinstance(images, (list, tuple)):
|
||||
# variable length images
|
||||
image_features = []
|
||||
for image in images:
|
||||
feats = self.vision_tower(image.unsqueeze(0))
|
||||
image_feature = feats[self.mm_vision_select_layer][:, 1:]
|
||||
image_features.append(image_feature)
|
||||
else:
|
||||
feats = self.vision_tower(images)
|
||||
image_features = feats[self.mm_vision_select_layer][:, 1:]
|
||||
|
||||
mm_projector = self.lang_encoder.model.mm_projector
|
||||
if isinstance(images, (list, tuple)):
|
||||
image_features = [
|
||||
mm_projector(image_feature)[0]
|
||||
for image_feature in image_features
|
||||
]
|
||||
else:
|
||||
image_features = mm_projector(image_features)
|
||||
|
||||
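# Projected dummy features are added as 0 * dummy to samples without image
# tokens, so text-only samples still pass through the projector.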
dummy_image_features = torch.zeros(
|
||||
256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
|
||||
dummy_image_features = mm_projector(dummy_image_features)
|
||||
|
||||
new_input_embeds = []
|
||||
cur_image_idx = 0
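# For each sample, splice the projected image features into the token
# embeddings at the <im_patch> placeholder positions.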
|
||||
for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
|
||||
if (cur_input_ids != self.im_patch_token).all():
|
||||
# multimodal LLM, but the current sample is not multimodal
|
||||
cur_input_embeds = cur_input_embeds + (
|
||||
0. * dummy_image_features).sum()
|
||||
new_input_embeds.append(cur_input_embeds)
|
||||
cur_image_idx += 1
|
||||
continue
|
||||
if self.use_im_start_end:
|
||||
cur_image_features = image_features[cur_image_idx]
|
||||
num_patches = cur_image_features.shape[0]
|
||||
if (cur_input_ids == self.im_start_token).sum() != (
|
||||
cur_input_ids == self.im_end_token).sum():
|
||||
raise ValueError('The number of image start tokens and '
|
||||
'image end tokens should be the same.')
|
||||
image_start_tokens = torch.where(
|
||||
cur_input_ids == self.im_start_token)[0]
|
||||
for image_start_token_pos in image_start_tokens:
|
||||
cur_image_features = image_features[cur_image_idx].to(
|
||||
device=cur_input_embeds.device)
|
||||
num_patches = cur_image_features.shape[0]
|
||||
if cur_input_ids[image_start_token_pos + num_patches +
|
||||
1] != self.im_end_token:
|
||||
raise ValueError('The image end token should follow '
|
||||
'the image start token.')
|
||||
cur_new_input_embeds = torch.cat(
|
||||
(cur_input_embeds[:image_start_token_pos + 1],
|
||||
cur_image_features,
|
||||
cur_input_embeds[image_start_token_pos + num_patches +
|
||||
1:]),
|
||||
dim=0)
|
||||
cur_image_idx += 1
|
||||
new_input_embeds.append(cur_new_input_embeds)
|
||||
else:
|
||||
cur_image_features = image_features[cur_image_idx]
|
||||
num_patches = cur_image_features.shape[0]
|
||||
if (cur_input_ids == self.im_patch_token).sum() != num_patches:
|
||||
print(f'Debug: num_patches: {num_patches}')
|
||||
raise ValueError(
|
||||
'The number of image patch tokens should '
|
||||
'be the same as the number of image patches.')
|
||||
masked_indices = torch.where(
|
||||
cur_input_ids == self.im_patch_token)[0]
|
||||
mask_index_start = masked_indices[0]
|
||||
if (masked_indices != torch.arange(
|
||||
mask_index_start,
|
||||
mask_index_start + num_patches,
|
||||
device=masked_indices.device,
|
||||
dtype=masked_indices.dtype)).any():
|
||||
raise ValueError(
|
||||
'The image patch tokens should be consecutive.')
|
||||
cur_new_input_embeds = torch.cat(
|
||||
(cur_input_embeds[:mask_index_start], cur_image_features,
|
||||
cur_input_embeds[mask_index_start + num_patches:]),
|
||||
dim=0)
|
||||
new_input_embeds.append(cur_new_input_embeds)
|
||||
cur_image_idx += 1
|
||||
inputs_embeds = torch.stack(new_input_embeds, dim=0)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
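# Reorder the cached key/value states to match the beams selected at each
# beam-search step.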
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (tuple(
|
||||
past_state.index_select(0, beam_idx)
|
||||
for past_state in layer_past), )
|
||||
return reordered_past
|
|
@ -79,4 +79,5 @@ Import:
|
|||
- configs/itpn/metafile.yml
|
||||
- configs/hivit/metafile.yml
|
||||
- configs/minigpt4/metafile.yml
|
||||
- configs/llava/metafile.yml
|
||||
- configs/otter/metafile.yml
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers.modeling_utils import load_state_dict
|
||||
|
||||
prog_description = """\
|
||||
Merge Llava delta weights and original weights,
|
||||
and save as MMPreTrain checkpoint.
|
||||
"""
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description=prog_description)
|
||||
parser.add_argument(
|
||||
'src_path', type=str, help='The original checkpoint dir')
|
||||
parser.add_argument(
|
||||
'delta_path', type=str, help='The delta checkpoint dir')
|
||||
parser.add_argument('dst_path', type=str, help='The saved checkpoint path')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def load_checkpoint(path: Path):
|
||||
if path.is_file():
|
||||
return torch.load(path)
|
||||
|
||||
state_dict = OrderedDict()
|
||||
for ckpt in chain(
|
||||
path.rglob('*.bin'), path.rglob('*.pth'),
|
||||
path.rglob('*.safetensors')):
|
||||
state_dict.update(load_state_dict(str(ckpt)))
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
if Path(args.src_path).exists():
|
||||
src_path = Path(args.src_path)
|
||||
else:
|
||||
src_path = Path(snapshot_download(args.src_path))
|
||||
src_state_dict = load_checkpoint(src_path)
|
||||
|
||||
if Path(args.delta_path).exists():
|
||||
delta_path = Path(args.delta_path)
|
||||
else:
|
||||
delta_path = Path(snapshot_download(args.delta_path))
|
||||
delta_state_dict = load_checkpoint(delta_path)
|
||||
|
||||
merged_state_dict = OrderedDict()
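# The released delta weights are (tuned - base); adding the original LLaMA
# weights back recovers the full model. The embedding and LM head need the
# shape-aware branch below because the delta has extra rows for new tokens.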
|
||||
for k, v in src_state_dict.items():
|
||||
if k in delta_state_dict:
|
||||
delta_v = delta_state_dict.pop(k)
|
||||
if k in ['model.embed_tokens.weight', 'lm_head.weight']:
|
||||
h, w = v.shape[:2]
|
||||
delta_v[:h, :w] += v
|
||||
v = delta_v
|
||||
else:
|
||||
v += delta_v
|
||||
merged_state_dict['model.lang_encoder.' + k] = v
|
||||
|
||||
for k, v in delta_state_dict.items():
|
||||
merged_state_dict['model.lang_encoder.' + k] = v
|
||||
|
||||
torch.save(merged_state_dict, args.dst_path)
|
||||
print('Done!!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|