@@ -16,46 +16,28 @@ Instruction tuning large language models (LLMs) using machine-generated instruct

<!-- [TABS-BEGIN] -->

**Prepare the checkpoint**

According to the license of LLaMA, we cannot provide the merged checkpoint directly. Please use the
script below to download and merge the checkpoint.

```shell
python tools/model_converters/llava-delta2mmpre.py huggyllama/llama-7b liuhaotian/LLaVA-Lightning-7B-delta-v1-1 ./LLaVA-Lightning-7B-delta-v1-1.pth
```
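
After the merge finishes, a quick sanity check confirms the checkpoint is readable and that the language-model weights carry the `model.lang_encoder.` prefix added by the converter (a minimal sketch; the path is the example output name used in the command above):

```python
import torch

# Path is the example output name used above.
state_dict = torch.load('./LLaVA-Lightning-7B-delta-v1-1.pth', map_location='cpu')

print(len(state_dict), 'tensors')
# Language-model keys are saved under the 'model.lang_encoder.' prefix,
# e.g. 'model.lang_encoder.model.embed_tokens.weight'.
print(next(iter(state_dict)))
```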

**Use the model**

```python
import torch
from mmpretrain import get_model, inference_model

model = get_model('llava-7b-v1_caption', pretrained='MERGED_CHECKPOINT_PATH', device='cuda')
out = inference_model(model, 'demo/cat-dog.png')
out = inference_model('llava-7b-v1_caption', 'demo/cat-dog.png', device='cuda')
print(out)
# {'pred_caption': 'In the image, there are two cats sitting on a blanket.'}
```
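
The released v1.5 checkpoints listed in the results table below work the same way through `inference_model`; a minimal sketch (model name taken from the table, the generated caption will vary):

```python
from mmpretrain import inference_model

out = inference_model('llava-7b-v1.5_caption', 'demo/cat-dog.png', device='cuda')
print(out)  # e.g. {'pred_caption': '...'}
```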

**Test Command**

Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).

Test:

```shell
python tools/test.py configs/llava/llava-7b-v1_caption.py MERGED_CHECKPOINT_PATH
```

<!-- [TABS-END] -->

## Models and results

### Image Caption on COCO

| Model                 | Params (M) |  BLEU-4  |  CIDER   |              Config              |        Download        |
| :-------------------- | :--------: | :------: | :------: | :------------------------------: | :--------------------: |
| `llava-7b-v1_caption` |  7045.82   | Upcoming | Upcoming | [config](llava-7b-v1_caption.py) | See the above tutorial |

| Model                   | Params (M) |               Config               |                                                    Download                                                     |
| :---------------------- | :--------: | :--------------------------------: | :-------------------------------------------------------------------------------------------------------------: |
| `llava-7b-v1_caption`   |  7045.82   |  [config](llava-7b-v1_caption.py)  |  [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth)  |
| `llava-7b-v1.5_caption` |  7062.90   | [config](llava-7b-v1.5_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |
| `llava-7b-v1.5_vqa`     |  7062.90   |   [config](llava-7b-v1.5_vqa.py)   | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |

## Citation
@@ -0,0 +1,76 @@
_base_ = '../_base_/default_runtime.py'

meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." # noqa: E501
image_size = 336
prompt_tmpl = f'''{meta_prompt} User: <image>
Describe the image in detail. ASSISTANT:'''

# model settings
model = dict(
    type='Llava',
    tokenizer=dict(
        type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
    vision_encoder=dict(
        type='VisionTransformer',
        arch='l',
        patch_size=14,
        img_size=image_size,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
        final_norm=False,
        out_type='raw',
        pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
        'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
    ),
    mm_hidden_size=1024,
    use_im_patch=False,
    use_im_start_end=False,
    mm_proj_depth=2,
    lang_encoder=dict(
        type='AutoModelForCausalLM',
        name_or_path='huggyllama/llama-7b',
    ),
    task='caption',
    prompt_tmpl=prompt_tmpl,
    generation_cfg=dict(num_beams=3, max_new_tokens=50, length_penalty=-1.0),
)

# data settings
data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(image_size, image_size),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='PackInputs', meta_keys=['image_id']),
]

test_dataloader = dict(
    batch_size=8,
    num_workers=5,
    dataset=dict(
        type='COCOCaption',
        data_root='data/coco',
        ann_file='annotations/coco_karpathy_val.json',
        pipeline=test_pipeline,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)

test_evaluator = dict(
    type='COCOCaption',
    ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)

# schedule settings
test_cfg = dict()
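
A minimal sketch of running this config end to end from Python. The paths are assumptions: the file above saved as `configs/llava/llava-7b-v1.5_caption.py`, and the checkpoint URL taken from the metafile further below; `get_model` can also be given a registered model name instead of a config path.

```python
from mmpretrain import get_model, inference_model

ckpt = ('https://download.openmmlab.com/mmclassification/v1/llava/'
        'llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth')
model = get_model(
    'configs/llava/llava-7b-v1.5_caption.py', pretrained=ckpt, device='cuda')
print(inference_model(model, 'demo/cat-dog.png'))
```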

@@ -0,0 +1,76 @@
_base_ = '../_base_/default_runtime.py'

meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." # noqa: E501
image_size = 336
prompt_tmpl = f'''{meta_prompt} User: <image>
{{question}} ASSISTANT:'''

# model settings
model = dict(
    type='Llava',
    tokenizer=dict(
        type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
    vision_encoder=dict(
        type='VisionTransformer',
        arch='l',
        patch_size=14,
        img_size=image_size,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
        final_norm=False,
        out_type='raw',
        pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
        'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
    ),
    mm_hidden_size=1024,
    use_im_patch=False,
    use_im_start_end=False,
    mm_proj_depth=2,
    lang_encoder=dict(
        type='AutoModelForCausalLM',
        name_or_path='huggyllama/llama-7b',
    ),
    task='vqa',
    prompt_tmpl=prompt_tmpl,
    generation_cfg=dict(max_new_tokens=100),
)

# data settings
data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(image_size, image_size),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='PackInputs', meta_keys=['image_id', 'question']),
]

test_dataloader = dict(
    batch_size=8,
    num_workers=5,
    dataset=dict(
        type='COCOCaption',
        data_root='data/coco',
        ann_file='annotations/coco_karpathy_val.json',
        pipeline=test_pipeline,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)

test_evaluator = dict(
    type='COCOCaption',
    ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)

# schedule settings
test_cfg = dict()
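
Because `prompt_tmpl` is an f-string, `{{question}}` survives as a literal `{question}` placeholder, which the model later fills via `str.format` (see the `@@ -207,16 +208,24 @@` hunk below). A small illustration, with the meta prompt shortened for readability:

```python
meta_prompt = 'A chat between a curious human and an artificial intelligence assistant.'
prompt_tmpl = f'''{meta_prompt} User: <image>
{{question}} ASSISTANT:'''

print(prompt_tmpl.format(question='What are the animals in the picture?'))
# A chat between a curious human and an artificial intelligence assistant. User: <image>
# What are the animals in the picture? ASSISTANT:
```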

@@ -1,16 +1,9 @@
_base_ = '../_base_/default_runtime.py'

meta_prompt = 'You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.Follow the instructions carefully and explain your answers in detail.' # noqa: E501
im_patch_token = '<im_patch>'
patch_size = 14
image_size = 224
num_patches = (image_size // patch_size)**2
caption_prompt = ' '.join([
    meta_prompt,
    'User: a photo of\n',
    im_patch_token * num_patches,
    'ASSISTANT:',
])
prompt_tmpl = f'''{meta_prompt} User: <im_start><image><im_end>
Describe the image in detail. ASSISTANT:'''

# model settings
model = dict(
@@ -22,6 +15,7 @@ model = dict(
        type='VisionTransformer',
        arch='l',
        patch_size=14,
        img_size=image_size,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
@@ -32,15 +26,16 @@ model = dict(
        'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
    ),
    mm_hidden_size=1024,
    use_im_start_end=False,
    use_mm_proj=True,
    use_im_patch=False,
    use_im_start_end=True,
    mm_proj_depth=1,
    lang_encoder=dict(
        type='AutoModelForCausalLM',
        name_or_path='huggyllama/llama-7b',
    ),
    task='caption',
    prompt_tmpl=caption_prompt,
    generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
    prompt_tmpl=prompt_tmpl,
    generation_cfg=dict(max_new_tokens=50),
)

# data settings
@@ -21,5 +21,31 @@ Models:
        Metrics:
          BLEU-4: null
          CIDER: null
    Weights: null
    Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth
    Config: configs/llava/llava-7b-v1_caption.py
  - Name: llava-7b-v1.5_caption
    Metadata:
      FLOPs: null
      Parameters: 7062900736
    In Collection: LLaVA
    Results:
      - Task: Image Caption
        Dataset: COCO
        Metrics:
          BLEU-4: null
          CIDER: null
    Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
    Config: configs/llava/llava-7b-v1.5_caption.py
  - Name: llava-7b-v1.5_vqa
    Metadata:
      FLOPs: null
      Parameters: 7062900736
    In Collection: LLaVA
    Results:
      - Task: Visual Question Answering
        Dataset: COCO
        Metrics:
          BLEU-4: null
          CIDER: null
    Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
    Config: configs/llava/llava-7b-v1.5_vqa.py
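
Once the metafile entries are in place, the new models should be discoverable through mmpretrain's model index; a quick check (a sketch, assuming an environment with this branch installed):

```python
from mmpretrain import list_models

# The entries added above should show up here.
print(list_models('llava*'))
# e.g. ['llava-7b-v1_caption', 'llava-7b-v1.5_caption', 'llava-7b-v1.5_vqa']
```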

@@ -24,8 +24,8 @@ class Llava(BaseModel):
        use_im_start_end (bool): Whether to use the im_start and im_end tokens
        mm_vision_select_layer (int): The index from vision encoder output.
            Defaults to -1.
        use_mm_proj (bool): Whether to enable multi-modal projection.
            Defaults to True.
        mm_proj_depth (int): The number of linear layers for multi-modal
            projection. Defaults to 1.
        load_lang_pretrained (bool): Whether to load the pretrained model of
            language encoder. Defaults to False.
        generation_cfg (dict): The extra generation config, accept the keyword
@@ -51,9 +51,10 @@ class Llava(BaseModel):
                 mm_hidden_size: int,
                 prompt_tmpl: str,
                 task: str = 'caption',
                 use_im_patch: bool = True,
                 use_im_start_end: bool = False,
                 mm_vision_select_layer: int = -1,
                 use_mm_proj: bool = True,
                 mm_proj_depth: int = 1,
                 generation_cfg: dict = dict(),
                 load_lang_pretrained: bool = False,
                 data_preprocessor: Optional[dict] = None,
@@ -75,7 +76,9 @@ class Llava(BaseModel):
        # init tokenizer
        self.tokenizer = TOKENIZER.build(tokenizer)
        # add Llava special tokens to the tokenizer
        self.tokenizer.add_tokens([self.im_patch_token], special_tokens=True)
        if use_im_patch:
            self.tokenizer.add_tokens([self.im_patch_token],
                                      special_tokens=True)
        if use_im_start_end:
            self.tokenizer.add_tokens([self.im_start_token, self.im_end_token],
                                      special_tokens=True)
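
For reference, the special-token handling above maps directly onto the Hugging Face tokenizer API; a standalone sketch (the tokenizer name is taken from the config, and the literal `<im_start>`/`<im_end>` strings are assumed to match the attributes used above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')
# Register the image placeholder tokens only when the model variant needs them.
tokenizer.add_tokens(['<im_patch>'], special_tokens=True)
tokenizer.add_tokens(['<im_start>', '<im_end>'], special_tokens=True)
print(tokenizer.convert_tokens_to_ids('<im_start>'))  # id of the new token
```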

@@ -108,14 +111,12 @@ class Llava(BaseModel):
            vision_encoder=vision_encoder,
            lang_encoder=lang_encoder,
            mm_hidden_size=mm_hidden_size,
            use_mm_proj=use_mm_proj,
            mm_proj_depth=mm_proj_depth,
            use_im_start_end=use_im_start_end,
            im_start_token=self.tokenizer.convert_tokens_to_ids(
                self.im_start_token),
            im_end_token=self.tokenizer.convert_tokens_to_ids(
                self.im_end_token),
            im_patch_token=self.tokenizer.convert_tokens_to_ids(
                self.im_patch_token),
            mm_vision_select_layer=mm_vision_select_layer)

        self.generation_cfg = generation_cfg
@@ -207,16 +208,24 @@ class Llava(BaseModel):
        Returns:
            List[DataSample]: Return list of data samples.
        """
        prompts = []
        tokens = []
        for sample in data_samples:
            final_prompt = self.prompt_tmpl.format(**sample.to_dict())
            prompts.append(final_prompt)
            prompt = self.prompt_tmpl.format(**sample.to_dict())
            input_ids = []
            while '<image>' in prompt:
                prefix, _, prompt = prompt.partition('<image>')
                input_ids.extend(
                    self.tokenizer(prefix, add_special_tokens=False).input_ids)
                input_ids.append(-200)
            if prompt:
                input_ids.extend(
                    self.tokenizer(prompt, add_special_tokens=False).input_ids)
            tokens.append(dict(input_ids=input_ids))

        self.tokenizer.padding_side = 'left'
        input_text = self.tokenizer(
            prompts,
        input_text = self.tokenizer.pad(
            tokens,
            padding='longest',
            truncation=True,
            return_tensors='pt',
            max_length=2000,
        ).to(device)
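
The loop above interleaves text token ids with a sentinel id (-200, the default `im_token_index` of `LlavaLlamaForCausalLM` below) wherever the `<image>` placeholder appears. A self-contained sketch of the same idea with a stub tokenizer (names here are illustrative, not mmpretrain API):

```python
IMAGE_TOKEN_INDEX = -200  # sentinel id later replaced by visual features


def stub_tokenize(text):
    # Stand-in for tokenizer(text, add_special_tokens=False).input_ids.
    return [hash(word) % 1000 for word in text.split()]


def build_input_ids(prompt):
    input_ids = []
    while '<image>' in prompt:
        prefix, _, prompt = prompt.partition('<image>')
        input_ids.extend(stub_tokenize(prefix))
        input_ids.append(IMAGE_TOKEN_INDEX)
    if prompt:
        input_ids.extend(stub_tokenize(prompt))
    return input_ids


print(build_input_ids('User: <image> Describe the image. ASSISTANT:'))
```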

@@ -31,10 +31,10 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
                 lang_encoder,
                 mm_hidden_size,
                 use_im_start_end=True,
                 use_mm_proj=True,
                 mm_proj_depth=1,
                 im_start_token: Optional[int] = None,
                 im_end_token: Optional[int] = None,
                 im_patch_token: Optional[int] = None,
                 im_token_index: int = -200,
                 mm_vision_select_layer: int = -1):
        super().__init__(lang_encoder.config)
        self.vision_tower = vision_encoder
@@ -43,16 +43,26 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
        self.use_im_start_end = use_im_start_end
        self.im_start_token = im_start_token
        self.im_end_token = im_end_token
        self.im_patch_token = im_patch_token
        self.mm_hidden_size = mm_hidden_size
        self.mm_vision_select_layer = mm_vision_select_layer
        self.im_token_index = im_token_index
        self.lang_hidden_size = lang_encoder.config.hidden_size

        if use_mm_proj and not hasattr(lang_encoder.model, 'mm_projector'):
        if mm_proj_depth == 1:
            # Llava V1
            mm_projector = nn.Linear(self.mm_hidden_size,
                                     self.lang_hidden_size)
            self.lang_encoder.model.add_module('mm_projector', mm_projector)
        elif not use_mm_proj:
        elif mm_proj_depth > 1:
            # Llava V1.5
            modules = [nn.Linear(self.mm_hidden_size, self.lang_hidden_size)]
            for _ in range(1, mm_proj_depth):
                modules.append(nn.GELU())
                modules.append(
                    nn.Linear(self.lang_hidden_size, self.lang_hidden_size))
            mm_projector = nn.Sequential(*modules)
            self.lang_encoder.model.add_module('mm_projector', mm_projector)
        elif mm_proj_depth == 0:
            self.lang_encoder.model.add_module('mm_projector', nn.Identity())

        self.post_init()
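
The two projector variants built above differ only in depth: LLaVA v1 uses a single linear layer from the CLIP feature width to the LLM hidden size, while v1.5 uses a small MLP with GELU in between. A standalone sketch with the sizes used in these configs (1024-dim CLIP-L patch features; 4096 is LLaMA-7B's hidden size, stated here as an assumption):

```python
import torch
from torch import nn

mm_hidden_size, lang_hidden_size = 1024, 4096

# mm_proj_depth == 1 (LLaVA v1): a single linear projection.
proj_v1 = nn.Linear(mm_hidden_size, lang_hidden_size)

# mm_proj_depth == 2 (LLaVA v1.5): Linear -> GELU -> Linear.
proj_v15 = nn.Sequential(
    nn.Linear(mm_hidden_size, lang_hidden_size),
    nn.GELU(),
    nn.Linear(lang_hidden_size, lang_hidden_size),
)

patch_features = torch.randn(1, 576, mm_hidden_size)  # 24x24 patches at 336px, patch size 14
print(proj_v15(patch_features).shape)  # torch.Size([1, 576, 4096])
```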

@@ -80,16 +90,12 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
            return_dict
            if return_dict is not None else self.config.use_return_dict)

        # decoder outputs consists of
        # (dec_features, layer_state, dec_hidden, dec_attn)
        if inputs_embeds is None:
            inputs_embeds = self.lang_encoder.model.embed_tokens(input_ids)

        inputs_embeds = self.forward_vision_tower(input_ids, inputs_embeds,
                                                  images)
        (input_ids, attention_mask, past_key_values, inputs_embeds,
         labels) = self.forward_vision_tower(input_ids, attention_mask,
                                             past_key_values, labels, images)

        return self.lang_encoder(
            input_ids=None,
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,

@@ -127,106 +133,93 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
    def forward_vision_tower(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        images: Union[torch.FloatTensor, list, None] = None,
        attention_mask: torch.LongTensor,
        past_key_values: torch.FloatTensor,
        labels: torch.LongTensor,
        images: Union[torch.FloatTensor, None] = None,
    ):
        if self.use_im_start_end:
            assert self.im_start_token is not None
            assert self.im_end_token is not None
        if images is not None:
            assert self.im_patch_token is not None

        if self.vision_tower is None or images is None or (
                input_ids.shape[1] == 1 and not self.training):
            return inputs_embeds
        if self.vision_tower is None or images is None or input_ids.shape[
                1] == 1:
            if (past_key_values is not None and self.vision_tower is not None
                    and images is not None and input_ids.shape[1] == 1):
                attention_mask = torch.ones(
                    (attention_mask.shape[0],
                     past_key_values[-1][-1].shape[-2] + 1),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device)
            return input_ids, attention_mask, past_key_values, None, labels

        with torch.no_grad():
            if isinstance(images, (list, tuple)):
                # variable length images
                image_features = []
                for image in images:
                    feats = self.vision_tower(image.unsqueeze(0))
                    image_feature = feats[self.mm_vision_select_layer][:, 1:]
                    image_features.append(image_feature)
            else:
                feats = self.vision_tower(images)
                image_features = feats[self.mm_vision_select_layer][:, 1:]
            # TODO: support variable number of images (single now)
            feats = self.vision_tower(images)
            image_features = feats[-1][:, 1:]

        mm_projector = self.lang_encoder.model.mm_projector
        if isinstance(images, (list, tuple)):
            image_features = [
                mm_projector(image_feature)[0]
                for image_feature in image_features
            ]
        else:
            image_features = mm_projector(image_features)

        dummy_image_features = torch.zeros(
            256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        dummy_image_features = mm_projector(dummy_image_features)
        image_features = self.lang_encoder.model.mm_projector(image_features)

        new_input_embeds = []
        cur_image_idx = 0
        for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
            if (cur_input_ids != self.im_patch_token).all():
                # multimodal LLM, but the current sample is not multimodal
                cur_input_embeds = cur_input_embeds + (
                    0. * dummy_image_features).sum()
                new_input_embeds.append(cur_input_embeds)
                cur_image_idx += 1
                continue
            if self.use_im_start_end:
                cur_image_features = image_features[cur_image_idx]
                num_patches = cur_image_features.shape[0]
                if (cur_input_ids == self.im_start_token).sum() != (
                        cur_input_ids == self.im_end_token).sum():
                    raise ValueError('The number of image start tokens and '
                                     'image end tokens should be the same.')
                image_start_tokens = torch.where(
                    cur_input_ids == self.im_start_token)[0]
                for image_start_token_pos in image_start_tokens:
                    cur_image_features = image_features[cur_image_idx].to(
                        device=cur_input_embeds.device)
                    num_patches = cur_image_features.shape[0]
                    if cur_input_ids[image_start_token_pos + num_patches +
                                     1] != self.im_end_token:
                        raise ValueError('The image end token should follow '
                                         'the image start token.')
                    cur_new_input_embeds = torch.cat(
                        (cur_input_embeds[:image_start_token_pos + 1],
                         cur_image_features,
                         cur_input_embeds[image_start_token_pos + num_patches +
                                          1:]),
                        dim=0)
                    cur_image_idx += 1
                new_input_embeds.append(cur_new_input_embeds)
            else:
                cur_image_features = image_features[cur_image_idx]
                num_patches = cur_image_features.shape[0]
                if (cur_input_ids == self.im_patch_token).sum() != num_patches:
                    print(f'Debug: num_patches: {num_patches}')
                    raise ValueError(
                        'The number of image patch tokens should '
                        'be the same as the number of image patches.')
                masked_indices = torch.where(
                    cur_input_ids == self.im_patch_token)[0]
                mask_index_start = masked_indices[0]
                if (masked_indices != torch.arange(
                        mask_index_start,
                        mask_index_start + num_patches,
                        device=masked_indices.device,
                        dtype=masked_indices.dtype)).any():
                    raise ValueError(
                        'The image patch tokens should be consecutive.')
                cur_new_input_embeds = torch.cat(
                    (cur_input_embeds[:mask_index_start], cur_image_features,
                     cur_input_embeds[mask_index_start + num_patches:]),
                    dim=0)
                new_input_embeds.append(cur_new_input_embeds)
                cur_image_idx += 1
        inputs_embeds = torch.stack(new_input_embeds, dim=0)
        new_labels = [] if labels is not None else None
        new_attn_mask = [] if attention_mask is not None else None
        for batch_idx, cur_input_ids in enumerate(input_ids):
            cur_img = image_features[batch_idx]

        return inputs_embeds
            if (cur_input_ids != self.im_token_index).all():
                # multimodal LLM, but the current sample is not multimodal
                new_input_embeds.append(self.embed_tokens(cur_input_ids))
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                if attention_mask is not None:
                    new_attn_mask.append(attention_mask[batch_idx])
                continue

            img_idx = torch.where(cur_input_ids == self.im_token_index)[0][0]
            if self.use_im_start_end:
                cur_new_input_embeds = torch.cat(
                    [
                        self.embed_tokens(cur_input_ids[:img_idx - 1]),
                        self.embed_tokens(cur_input_ids[img_idx - 1:img_idx]),
                        cur_img,
                        self.embed_tokens(
                            cur_input_ids[img_idx + 1:img_idx + 2]),
                        self.embed_tokens(cur_input_ids[img_idx + 2:]),
                    ],
                    dim=0,
                )
            else:
                cur_new_input_embeds = torch.cat(
                    [
                        self.embed_tokens(cur_input_ids[:img_idx]),
                        cur_img,
                        self.embed_tokens(cur_input_ids[img_idx + 1:]),
                    ],
                    dim=0,
                )
            new_input_embeds.append(cur_new_input_embeds)

            if labels is not None:
                cur_new_labels = torch.cat([
                    labels[batch_idx, :img_idx],
                    labels.new_full((cur_img.size(0), ), -100),
                    labels[batch_idx, img_idx + 1:],
                ],
                                           dim=0)
                new_labels.append(cur_new_labels)

            if attention_mask is not None:
                cur_attn_mask = torch.cat([
                    attention_mask[batch_idx, :img_idx],
                    attention_mask.new_full((cur_img.size(0), ), True),
                    attention_mask[batch_idx, img_idx + 1:],
                ],
                                          dim=0)
                new_attn_mask.append(cur_attn_mask)

        inputs_embeds = torch.stack(new_input_embeds, dim=0)
        if labels is not None:
            labels = torch.stack(new_labels, dim=0)
        if attention_mask is not None:
            attention_mask = torch.stack(new_attn_mask, dim=0)

        return None, attention_mask, past_key_values, inputs_embeds, labels
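
The rewritten loop above splices the projected image patch features into each sequence at the `im_token_index` placeholder (-200 by default) and pads `labels` and `attention_mask` to match. A minimal standalone sketch of the embedding splice with dummy shapes (not the mmpretrain API):

```python
import torch

vocab, hidden, num_patches = 100, 8, 4
embed = torch.nn.Embedding(vocab, hidden)
IMAGE_TOKEN_INDEX = -200

input_ids = torch.tensor([1, 42, IMAGE_TOKEN_INDEX, 7, 9])  # -200 marks the image slot
img_feats = torch.randn(num_patches, hidden)                # projected patch features

img_idx = torch.where(input_ids == IMAGE_TOKEN_INDEX)[0][0]
inputs_embeds = torch.cat([
    embed(input_ids[:img_idx]),      # text tokens before the image
    img_feats,                       # patch features take the placeholder's slot
    embed(input_ids[img_idx + 1:]),  # text tokens after the image
], dim=0)
print(inputs_embeds.shape)  # torch.Size([8, 8]): 4 text tokens + 4 patches
```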

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
@@ -236,3 +229,6 @@ class LlavaLlamaForCausalLM(PreTrainedModel):
                past_state.index_select(0, beam_idx)
                for past_state in layer_past), )
        return reordered_past

    def embed_tokens(self, input_ids):
        return self.lang_encoder.model.embed_tokens(input_ids)

@@ -9,23 +9,21 @@ from huggingface_hub import snapshot_download
from transformers.modeling_utils import load_state_dict

prog_description = """\
Merge Llava delta weights and original weights,
and save as MMPreTrain checkpoint.
Convert Llava weights and original weights.
"""


def parse_args():
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument(
        'src_path', type=str, help='The original checkpoint dir')
    parser.add_argument(
        'delta_path', type=str, help='The delta checkpoint dir')
    parser.add_argument('dst_path', type=str, help='The saved checkpoint path')
    parser.add_argument('src', type=str, help='The original checkpoint dir')
    parser.add_argument('dst', type=str, help='The saved checkpoint path')
    parser.add_argument('--delta', type=str, help='The delta checkpoint dir')
    args = parser.parse_args()
    return args


def load_checkpoint(path: Path):
    path = Path(path)
    if path.is_file():
        return torch.load(path)
@@ -41,19 +39,23 @@ def load_checkpoint(path: Path):
def main():
    args = parse_args()

    if Path(args.src_path).exists():
        src_path = Path(args.src_path)
    if Path(args.src).exists():
        src_path = args.src
    else:
        src_path = Path(snapshot_download(args.src_path))
        src_path = snapshot_download(
            args.src, allow_patterns='pytorch_model*.bin')
    src_state_dict = load_checkpoint(src_path)

    if Path(args.delta_path).exists():
        delta_path = Path(args.delta_path)
    if args.delta is None:
        delta_state_dict = {}
    elif Path(args.delta).exists():
        delta_state_dict = load_checkpoint(args.delta)
    else:
        delta_path = Path(snapshot_download(args.delta_path))
        delta_state_dict = load_checkpoint(delta_path)
        delta_path = snapshot_download(
            args.delta, allow_patterns='pytorch_model*.bin')
        delta_state_dict = load_checkpoint(delta_path)

    merged_state_dict = OrderedDict()
    new_state_dict = OrderedDict()
    for k, v in src_state_dict.items():
        if k in delta_state_dict:
            delta_v = delta_state_dict.pop(k)
@@ -63,12 +65,13 @@ def main():
                v = delta_v
            else:
                v += delta_v
        merged_state_dict['model.lang_encoder.' + k] = v
        if 'rotary_emb.inv_freq' not in k:
            new_state_dict['model.lang_encoder.' + k] = v

    for k, v in delta_state_dict.items():
        merged_state_dict['model.lang_encoder.' + k] = v
        new_state_dict['model.lang_encoder.' + k] = v

    torch.save(merged_state_dict, args.dst_path)
    torch.save(new_state_dict, args.dst)
    print('Done!!')
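
The core of the conversion is small: merge the delta weights onto the base LLaMA weights when a `--delta` checkpoint is given (the hunk above shows both an additive branch and a direct-replacement branch), drop `rotary_emb.inv_freq` buffers, and prefix every key with `model.lang_encoder.` so it matches the mmpretrain module layout. A toy illustration of the additive path with dummy tensors (a sketch only, not a substitute for running the script):

```python
from collections import OrderedDict

import torch

src = {'model.embed_tokens.weight': torch.ones(2, 2),
       'model.layers.0.self_attn.rotary_emb.inv_freq': torch.ones(2)}
delta = {'model.embed_tokens.weight': torch.full((2, 2), 0.5)}

new_state_dict = OrderedDict()
for k, v in src.items():
    if k in delta:
        v = v + delta.pop(k)  # merge the delta weights onto the base weights
    if 'rotary_emb.inv_freq' not in k:
        new_state_dict['model.lang_encoder.' + k] = v

print(list(new_state_dict))  # ['model.lang_encoder.model.embed_tokens.weight']
```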