[Feature] Add minigpt4 gradio demo and training script. (#1758)
* Add minigpt4 gradio demo
* update minigpt4 demo
* update minigpt4 demo (inference with float16)
* update minigpt4 and some dependent files
* add minigpt4 dataset for training
* add training script for minigpt4
* restore files deleted by mistake
* fix an error
* remove useless modification
* provide command line arguments for minigpt4 gradio demo and update some comments
* update code
* Update minigpt-4 readme

Co-authored-by: mzr1996 <mzr1996@163.com>
parent 5c71de6b8e · commit c0766519b1
@@ -34,9 +34,10 @@ For Vicuna model, please refer to [MiniGPT-4 page](https://github.com/Vision-CAI

 ### Pretrained models

-| Model                           | Params (M) | Flops (G) |                  Config                  |                                                     Download                                                     |
-| :------------------------------ | :--------: | :-------: | :--------------------------------------: | :------------------------------------------------------------------------------------------------------------: |
-| `minigpt-4_vicuna-7b_caption`\* | 8121.32    | N/A       | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth) |
+| Model                           | Params (M) | Flops (G) |                   Config                   |                                                    Download                                                    |
+| :------------------------------ | :--------: | :-------: | :----------------------------------------: | :----------------------------------------------------------------------------------------------------------: |
+| `minigpt-4_baichuan-7b_caption` | 8094.77    | N/A       | [config](minigpt-4_baichuan-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth) |
+| `minigpt-4_vicuna-7b_caption`\* | 8121.32    | N/A       | [config](minigpt-4_vicuna-7b_caption.py)   | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth)   |

 *Models with \* are converted from the [official repo](https://github.com/Vision-CAIR/MiniGPT-4/tree/main). The config files of these models are only for inference. We haven't reproduced the training results.*
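For quick reference (not part of this diff), a minimal inference sketch against the Baichuan checkpoint listed above. It assumes the multimodal extras of MMPretrain are installed and that `baichuan-inc/baichuan-7B` can be fetched from HuggingFace on first use; the image path is a placeholder.

```python
from mmpretrain import ImageCaptionInferencer

# Config and checkpoint are the ones released with this PR (see table above).
inferencer = ImageCaptionInferencer(
    model='configs/minigpt4/minigpt-4_baichuan-7b_caption.py',
    pretrained='https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth')

result = inferencer('path/to/your_image.jpg')[0]  # placeholder image path
print(result['pred_caption'])
```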
@@ -19,8 +19,19 @@ Models:
       - Task: Image Caption
         Dataset: COCO
         Metrics: null
-    Weights: https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth
+    Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth
     Config: configs/minigpt4/minigpt-4_vicuna-7b_caption.py
     Converted From:
       Weights: https://github.com/Vision-CAIR/MiniGPT-4/tree/main
       Code: https://github.com/Vision-CAIR/MiniGPT-4/tree/main
+  - Name: minigpt-4_baichuan-7b_caption
+    Metadata:
+      FLOPs: null
+      Parameters: 8094769024
+    In Collection: MiniGPT4
+    Results:
+      - Task: Image Caption
+        Dataset: COCO
+        Metrics: null
+    Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth
+    Config: configs/minigpt4/minigpt-4_baichuan-7b_caption.py
@@ -0,0 +1,190 @@
_base_ = [
    '../_base_/default_runtime.py',
]

data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

# dataset settings
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(224, 224),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='CleanCaption',
        keys='chat_content',
        remove_chars='',
        lowercase=False),
    dict(
        type='PackInputs',
        algorithm_keys=['chat_content', 'lang'],
        meta_keys=['image_id']),
]

train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    dataset=dict(
        type='MiniGPT4Dataset',
        data_root='YOUR_DATA_DIRECTORY',
        ann_file='YOUR_DATA_FILE',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    drop_last=False,
)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(224, 224),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='PackInputs', meta_keys=['image_id']),
]

test_evaluator = dict(
    type='COCOCaption',
    ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)

test_dataloader = dict(
    batch_size=1,
    dataset=dict(
        type='COCOCaption',
        data_root='data/coco',
        ann_file='annotations/coco_karpathy_val.json',
        pipeline=test_pipeline))

# model settings
model = dict(
    type='MiniGPT4',
    vision_encoder=dict(
        type='BEiTViT',
        # eva-g without the final layer
        arch=dict(
            embed_dims=1408,
            num_layers=39,
            num_heads=16,
            feedforward_channels=6144,
        ),
        img_size=224,
        patch_size=14,
        layer_scale_init_value=0.0,
        frozen_stages=39,
        use_abs_pos_emb=True,
        use_rel_pos_bias=False,
        final_norm=False,
        use_shared_rel_pos_bias=False,
        out_type='raw',
        pretrained=  # noqa
        'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_eva-g-p14_20230615-e908c021.pth'  # noqa
    ),
    q_former_model=dict(
        type='Qformer',
        model_style='bert-base-uncased',
        vision_model_width=1408,
        add_cross_attention=True,
        cross_attention_freq=2,
        num_query_token=32,
        pretrained=  # noqa
        'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_qformer_20230615-1dfa889c.pth'  # noqa
    ),
    lang_encoder=dict(
        type='AutoModelForCausalLM',
        name_or_path='baichuan-inc/baichuan-7B',
        trust_remote_code=True),
    tokenizer=dict(
        type='AutoTokenizer',
        name_or_path='baichuan-inc/baichuan-7B',
        trust_remote_code=True),
    task='caption',
    prompt_template=dict([('en', '###Ask: {} ###Answer: '),
                          ('zh', '###问:{} ###答:')]),
    raw_prompts=dict([
        ('en', [('<Img><ImageHere></Img> '
                 'Describe this image in detail.'),
                ('<Img><ImageHere></Img> '
                 'Take a look at this image and describe what you notice.'),
                ('<Img><ImageHere></Img> '
                 'Please provide a detailed description of the picture.'),
                ('<Img><ImageHere></Img> '
                 'Could you describe the contents of this image for me?')]),
        ('zh', [('<Img><ImageHere></Img> '
                 '详细描述这张图片。'), ('<Img><ImageHere></Img> '
                                '浏览这张图片并描述你注意到什么。'),
                ('<Img><ImageHere></Img> '
                 '请对这张图片进行详细的描述。'),
                ('<Img><ImageHere></Img> '
                 '你能为我描述这张图片的内容吗?')])
    ]),
    max_txt_len=160,
    end_sym='###')

strategy = dict(
    type='DeepSpeedStrategy',
    fp16=dict(
        enabled=True,
        auto_cast=False,
        fp16_master_weights_and_grads=False,
        loss_scale=0,
        loss_scale_window=1000,
        hysteresis=1,
        min_loss_scale=1,
        initial_scale_power=16,
    ),
    inputs_to_half=[0],
    zero_optimization=dict(
        stage=2,
        allgather_partitions=True,
        allgather_bucket_size=2e8,
        reduce_scatter=True,
        reduce_bucket_size='auto',
        overlap_comm=True,
        contiguous_gradients=True,
    ),
)

# schedule settings
optim_wrapper = dict(
    type='DeepSpeedOptimWrapper',
    optimizer=dict(type='AdamW', lr=1e-3, weight_decay=0.05))

param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-3 / 500,
        by_epoch=False,
        begin=0,
        end=500,
    ),
    dict(
        type='CosineAnnealingLR',
        eta_min=2e-4,
        by_epoch=False,
        begin=500,
    ),
]

train_cfg = dict(by_epoch=True, max_epochs=6)
test_cfg = dict()

runner_type = 'FlexibleRunner'

default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=1,
        by_epoch=True,
        save_last=True,
        max_keep_ckpts=1,
    ))
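For context (not part of this diff), a hedged sketch of launching a training run with this config from Python. It follows the usual MMEngine pattern for configs that set `runner_type` (here `FlexibleRunner`), and assumes DeepSpeed is installed and that the dataset placeholders above have been filled in; all paths are illustrative.

```python
from mmengine.config import Config
from mmengine.registry import RUNNERS

cfg = Config.fromfile('configs/minigpt4/minigpt-4_baichuan-7b_caption.py')
cfg.work_dir = 'work_dirs/minigpt-4_baichuan-7b_caption'  # hypothetical output dir

# Fill in the dataset placeholders; an absolute ann_file avoids any extra
# data_root prefixing by the dataset base class.
cfg.train_dataloader.dataset.data_root = '/path/to/minigpt4_dataset'
cfg.train_dataloader.dataset.ann_file = '/path/to/minigpt4_dataset/conversation_data.json'

# `runner_type = 'FlexibleRunner'` selects the strategy-aware runner that
# understands the DeepSpeed settings declared above.
runner = RUNNERS.build(cfg)
runner.train()
```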
@@ -55,13 +55,25 @@ model = dict(
         type='AutoModelForCausalLM', name_or_path='YOUR_PATH_TO_VICUNA'),
     tokenizer=dict(type='LlamaTokenizer', name_or_path='YOUR_PATH_TO_VICUNA'),
     task='caption',
-    prompt_template='###Human: {} ###Assistant: ',
-    raw_prompts=[
-        '<Img><ImageHere></Img> Describe this image in detail.',
-        '<Img><ImageHere></Img> Take a look at this image and describe what you notice.',  # noqa
-        '<Img><ImageHere></Img> Please provide a detailed description of the picture.',  # noqa
-        '<Img><ImageHere></Img> Could you describe the contents of this image for me?',  # noqa
-    ],
+    prompt_template=dict([('en', '###Ask: {} ###Answer: '),
+                          ('zh', '###问:{} ###答:')]),
+    raw_prompts=dict([
+        ('en', [('<Img><ImageHere></Img> '
+                 'Describe this image in detail.'),
+                ('<Img><ImageHere></Img> '
+                 'Take a look at this image and describe what you notice.'),
+                ('<Img><ImageHere></Img> '
+                 'Please provide a detailed description of the picture.'),
+                ('<Img><ImageHere></Img> '
+                 'Could you describe the contents of this image for me?')]),
+        ('zh', [('<Img><ImageHere></Img> '
+                 '详细描述这张图片。'), ('<Img><ImageHere></Img> '
+                                '浏览这张图片并描述你注意到什么。'),
+                ('<Img><ImageHere></Img> '
+                 '请对这张图片进行详细的描述。'),
+                ('<Img><ImageHere></Img> '
+                 '你能为我描述这张图片的内容吗?')])
+    ]),
     max_txt_len=160,
     end_sym='###')

@@ -43,6 +43,7 @@ if WITH_MULTIMODAL:
     from .gqa_dataset import GQA
     from .iconqa import IconQA
     from .infographic_vqa import InfographicVQA
+    from .minigpt4_dataset import MiniGPT4Dataset
     from .nocaps import NoCaps
     from .ocr_vqa import OCRVQA
     from .refcoco import RefCOCO
@@ -56,5 +57,6 @@ if WITH_MULTIMODAL:
         'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
         'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
         'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
-        'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA'
+        'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA',
+        'MiniGPT4Dataset'
     ])
@@ -0,0 +1,79 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class MiniGPT4Dataset(BaseDataset):
    """Dataset for training MiniGPT4.

    MiniGPT4 dataset directory:

        minigpt4_dataset
            ├── image
            │   ├── id0.jpg
            │   │── id1.jpg
            │   │── id2.jpg
            │   └── ...
            └── conversation_data.json

    The structure of conversation_data.json:

        [
            // English data
            {
                "id": str(id0),
                "conversation": "###Ask: <Img><ImageHere></Img> [Ask content]
                                 ###Answer: [Answer content]"
            },

            // Chinese data
            {
                "id": str(id1),
                "conversation": "###问:<Img><ImageHere></Img> [Ask content]
                                 ###答:[Answer content]"
            },

            ...
        ]

    Args:
        data_root (str): The root directory for ``ann_file`` and ``image``.
        ann_file (str): Conversation file path.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def load_data_list(self) -> List[dict]:
        file_backend = get_file_backend(self.data_root)
        conversation_path = file_backend.join_path(self.data_root,
                                                   self.ann_file)
        conversation = mmengine.load(conversation_path)
        img_ids = {}
        n = 0
        for conv in conversation:
            img_id = conv['id']
            if img_id not in img_ids.keys():
                img_ids[img_id] = n
                n += 1

        img_root = file_backend.join_path(self.data_root, 'image')
        data_list = []
        for conv in conversation:
            img_file = '{}.jpg'.format(conv['id'])
            chat_content = conv['conversation']
            lang = 'en' if chat_content.startswith('###Ask: ') else 'zh'
            data_info = {
                'image_id': img_ids[conv['id']],
                'img_path': file_backend.join_path(img_root, img_file),
                'chat_content': chat_content,
                'lang': lang,
            }

            data_list.append(data_info)

        return data_list
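To make the expected layout concrete (not part of this diff), a hedged, self-contained sketch that writes a one-sample toy dataset in the format described by the docstring above and loads it with `MiniGPT4Dataset`. All paths and contents are illustrative, and the multimodal extras of MMPretrain are assumed to be installed.

```python
import json
import os

import numpy as np
from PIL import Image

from mmpretrain.datasets import MiniGPT4Dataset

root = 'data/minigpt4_toy'  # hypothetical directory
os.makedirs(os.path.join(root, 'image'), exist_ok=True)

# One dummy image plus one English conversation entry, following the
# directory layout from the docstring above.
Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)).save(
    os.path.join(root, 'image', 'id0.jpg'))
with open(os.path.join(root, 'conversation_data.json'), 'w') as f:
    json.dump([{
        'id': 'id0',
        'conversation': '###Ask: <Img><ImageHere></Img> Describe this image. '
                        '###Answer: A plain black square.'
    }], f)

# An absolute ann_file avoids BaseDataset additionally prefixing it with
# data_root; with an empty pipeline the raw annotation dict is returned as-is.
ann_file = os.path.abspath(os.path.join(root, 'conversation_data.json'))
dataset = MiniGPT4Dataset(data_root=root, ann_file=ann_file, pipeline=[])
sample = dataset[0]
print(sample['lang'], sample['img_path'])  # -> en  data/minigpt4_toy/image/id0.jpg
```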
@@ -31,12 +31,12 @@ class MiniGPT4(BaseModel):
             True.
         num_query_token (int): Number of query tokens of Qformer. Defaults to
             32.
-        prompt_template (str): Prompt template of the model. Defaults to
-            '###Human: {} ###Assistant: '.
-        raw_prompts (list): Prompts for training. Defaults to None.
+        prompt_template (dict): Multi-language prompt template of the model.
+            Defaults to dict([('en', '###Ask: {} ###Answer: '),
+            ('zh', '###问:{} ###答:')]).
+        raw_prompts (dict): Prompts for training. Defaults to dict().
         max_txt_len (int): Max token length while doing tokenization. Defaults
             to 32.
-        end_sym (str): Ended symbol of the sequence. Defaults to '\\n'.
+        end_sym (str): Ended symbol of the sequence. Defaults to '###'.
         generation_cfg (dict): The config of text generation. Defaults to
             dict().
         data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
@@ -54,10 +54,12 @@ class MiniGPT4(BaseModel):
                  freeze_vit: bool = True,
                  freeze_q_former: bool = True,
                  num_query_token: int = 32,
-                 prompt_template: str = '###Human: {} ###Assistant: ',
-                 raw_prompts: Optional[list] = None,
+                 prompt_template: dict = dict([('en',
+                                                '###Ask: {} ###Answer: '),
+                                               ('zh', '###问:{} ###答:')]),
+                 raw_prompts: dict = dict(),
                  max_txt_len: int = 32,
-                 end_sym: str = '\n',
+                 end_sym: str = '###',
                  generation_cfg: dict = dict(),
                  data_preprocessor: Optional[dict] = None,
                  init_cfg: Optional[dict] = None):
@@ -135,16 +137,23 @@ class MiniGPT4(BaseModel):
         self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1]

         # set prompts
-        if raw_prompts is not None:
-            filted_prompts = [
-                raw_prompt for raw_prompt in raw_prompts
-                if '<ImageHere>' in raw_prompt
-            ]
-            self.prompt_list = [
-                prompt_template.format(p) for p in filted_prompts
-            ]
-        else:
-            self.prompt_list = []
+        self.en_prompt_list, self.zh_prompt_list = [], []
+        if raw_prompts.get('en') is not None:
+            en_filted_prompts = [
+                raw_prompt for raw_prompt in raw_prompts['en']
+                if '<ImageHere>' in raw_prompt
+            ]
+            self.en_prompt_list = [
+                prompt_template['en'].format(p) for p in en_filted_prompts
+            ]
+        if raw_prompts.get('zh') is not None:
+            zh_filted_prompts = [
+                raw_prompt for raw_prompt in raw_prompts['zh']
+                if '<ImageHere>' in raw_prompt
+            ]
+            self.zh_prompt_list = [
+                prompt_template['zh'].format(p) for p in zh_filted_prompts
+            ]

         # update generation configs
         self.generation_cfg = dict(
@@ -153,7 +162,7 @@ class MiniGPT4(BaseModel):
             do_sample=True,
             min_length=1,
             top_p=0.9,
-            repetition_penalty=1.0,
+            repetition_penalty=1.1,
             length_penalty=1.0,
             temperature=1.0)
         self.generation_cfg.update(**generation_cfg)
@@ -161,6 +170,10 @@ class MiniGPT4(BaseModel):
         if hasattr(self, 'register_load_state_dict_post_hook'):
             self.register_load_state_dict_post_hook(self._load_llama_proj_hook)

+    def half(self):
+        self.llama_model = self.llama_model.half()
+        return self
+
     def encode_img(self,
                    images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """The function to encode the images."""
@@ -184,33 +197,39 @@ class MiniGPT4(BaseModel):
         return inputs_llama, atts_llama

     def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor,
-                    prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
+                    prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
         """The function to wrap the image and prompt.

-        Currently, the function only supports applying one prompt to all input
-        images in the one batch.
+        Make sure that len(prompt) == img_embeds.shape[0].

         Args:
             img_embeds (torch.Tensor): The embedding of the input images.
             atts_img (torch.Tensor): Attention map of the image embeddings.
-            prompt (str): The prompt of the batch data.
+            prompt (List[str]): The prompt of the batch data.

         Returns:
             Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map.
         """
-        if prompt:
-            batch_size = img_embeds.shape[0]
-            p_before, p_after = prompt.split('<ImageHere>')
+        if len(prompt) > 0:
+            p_before_list, p_after_list = [], []
+            for pro in prompt:
+                p_before, p_after = pro.split('<ImageHere>')
+                p_before_list.append(p_before)
+                p_after_list.append(p_after)
             p_before_tokens = self.llama_tokenizer(
-                p_before, return_tensors='pt',
+                p_before_list,
+                return_tensors='pt',
+                padding='longest',
                 add_special_tokens=False).to(img_embeds.device)
             p_after_tokens = self.llama_tokenizer(
-                p_after, return_tensors='pt',
+                p_after_list,
+                return_tensors='pt',
+                padding='longest',
                 add_special_tokens=False).to(img_embeds.device)
             p_before_embeds = self.llama_model.model.embed_tokens(
-                p_before_tokens.input_ids).expand(batch_size, -1, -1)
+                p_before_tokens.input_ids)
             p_after_embeds = self.llama_model.model.embed_tokens(
-                p_after_tokens.input_ids).expand(batch_size, -1, -1)
+                p_after_tokens.input_ids)
             wrapped_img_embeds = torch.cat(
                 [p_before_embeds, img_embeds, p_after_embeds], dim=1)
             wrapped_atts_img = atts_img[:, :1].expand(
@@ -234,17 +253,22 @@ class MiniGPT4(BaseModel):
         """
         img_embeds, atts_img = self.encode_img(images)

-        if self.task == 'caption' and self.prompt_list:
-            prompt = random.choice(self.prompt_list)
-            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
-                                                    prompt)
-
         self.llama_tokenizer.padding_side = 'right'

-        text = [t + self.end_sym for t in data_samples['text_input']]
+        prompts, texts = [], []
+        for t in data_samples:
+            chat_content = t.chat_content
+            split_mark = '###Answer: ' if t.lang == 'en' else '###答:'
+            prompt, text = chat_content.split(split_mark)
+            prompt += split_mark
+            text += self.end_sym
+            prompts.append(prompt)
+            texts.append(text)
+
+        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)

         to_regress_tokens = self.llama_tokenizer(
-            text,
+            texts,
             return_tensors='pt',
             padding='longest',
             truncation=True,
@@ -295,10 +319,12 @@ class MiniGPT4(BaseModel):
         with torch.no_grad():
             img_embeds, atts_img = self.encode_img(images)

-        if self.task == 'caption' and self.prompt_list:
-            prompt = random.choice(self.prompt_list)
-            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
-                                                    prompt)
+        prompts = [
+            random.choice(self.zh_prompt_list) if hasattr(t, 'lang')
+            and t.lang == 'zh' else random.choice(self.en_prompt_list)
+            for t in data_samples
+        ]
+        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)

         batch_size = img_embeds.shape[0]
         bos = torch.ones(
@@ -336,7 +362,6 @@ class MiniGPT4(BaseModel):
         for output, data_sample in zip(outputs, data_samples):
             if self.task == 'caption':
                 output = output.split('###')[0]
-                output = output.split('Assistant:')[-1].strip()
                 data_sample.pred_caption = output
             else:
                 # raw output
@@ -0,0 +1,137 @@
# Modified from
# https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py
import dataclasses
from typing import List

import torch


@dataclasses.dataclass
class Conversation:
    system: str
    roles: List[str]
    messages: List[List[str]]
    sep: str = '###'

    def get_prompt(self):
        ret = self.system + self.sep
        for role, message in self.messages:
            if message:
                ret += role + ': ' + message + self.sep
            else:
                ret += role + ':'
        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def copy(self):
        return Conversation(
            system=self.system,
            roles=[role for role in self.roles],
            messages=[[y for y in x] for x in self.messages],
            sep=self.sep,
        )

    def dict(self):
        return {
            'system': self.system,
            'roles': self.roles,
            'messages': self.messages,
            'offset': self.offset,
            'sep': self.sep,
        }


EN_CONV_VISION = Conversation(
    system='Give the following image. '
    'You will be able to see the image once I provide it to you. '
    'Please answer my questions in detail.',
    roles=['Ask', 'Answer'],
    messages=[],
    sep='###',
)

ZH_CONV_VISION = Conversation(
    system='给定一张图片,请仔细观察这张图片,并回答我的问题。',
    roles=['问', '答'],
    messages=[],
    sep='###',
)


class Chat:

    def __init__(self, inferencer, device, is_half=False):
        self.device = device
        self.inferencer = inferencer
        self.model = inferencer.model
        self.is_half = is_half
        if is_half:
            self.model = self.model.half()
        self.model = self.model.to(device)
        self.max_length = 2000

    def upload_img(self, image, conv, img_list):
        img = next(self.inferencer.preprocess([image]))
        img = self.model.data_preprocessor(img, False)['images']
        img = img.to(self.device)
        image_emb, _ = self.model.encode_img(img)
        img_list.append(image_emb)
        conv.append_message(conv.roles[0], '<Img><ImageHere></Img>')

    def get_context_emb(self, conv, img_list):
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<ImageHere>')
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors='pt',
                add_special_tokens=(i == 0)).to(self.device).input_ids
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [
            self.model.llama_model.model.embed_tokens(seg_token)
            for seg_token in seg_tokens
        ]
        mixed_embs = [
            emb for pair in zip(seg_embs[:-1], img_list) for emb in pair
        ] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs

    def ask(self, text, conv):
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[
                0] and conv.messages[-1][1][-6:] == '</Img>':
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer(self, conv, img_list, generation_cfg):
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)
        cur_max_len = generation_cfg['max_new_tokens'] + embs.shape[1]
        if cur_max_len > self.max_length:
            print('Warning: The number of tokens in current conversation'
                  'exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, cur_max_len - self.max_length)
        embs = embs[:, begin_idx:]
        if self.is_half:
            embs = embs.half()
        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            eos_token_id=self.model.end_token_id,
            **generation_cfg)

        output_token = outputs[0]
        if output_token[0] == 0:
            output_token = output_token[1:]
        elif output_token[0] == 1:
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(
            output_token,
            add_special_tokens=False,
            skip_special_tokens=True)
        output_text = output_text.split('###')[0]
        conv.messages[-1][1] = output_text
        return output_text
@@ -0,0 +1,144 @@
import argparse

import gradio as gr
import numpy as np
import torch
from conversation import EN_CONV_VISION, ZH_CONV_VISION, Chat

from mmpretrain import ImageCaptionInferencer

parser = argparse.ArgumentParser(description='MiniGPT4 demo')
parser.add_argument(
    'cfg', type=str, help='config file for minigpt4 (absolute path)')
parser.add_argument(
    'ckpt', type=str, help='pretrained file for minigpt4 (absolute path)')
args = parser.parse_args()

if torch.cuda.is_available():
    devices = [
        torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
    ]
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    devices = [torch.device('mps')]
else:
    devices = [torch.device('cpu')]


def get_free_device():
    if hasattr(torch.cuda, 'mem_get_info'):
        free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices]
        select = max(zip(free, range(len(free))))[1]
    else:
        import random
        select = random.randint(0, len(devices) - 1)
    return devices[select]


device = get_free_device()
inferencer = ImageCaptionInferencer(model=args.cfg, pretrained=args.ckpt)
model = inferencer.model
chat = Chat(inferencer, device=device, is_half=(device.type != 'cpu'))


def reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return (None, gr.update(value=None, interactive=True),
            gr.update(
                value=None,
                placeholder='Please upload your image first',
                interactive=False),
            gr.update(value='Upload & Start Chat',
                      interactive=True), chat_state, img_list,
            gr.update(value='Restart', interactive=False),
            gr.update(value='English', interactive=True))


def upload_img(gr_img, language, chat_state):
    if gr_img is None:
        return (None,
                gr.update(
                    placeholder='Please upload your image first',
                    interactive=False),
                gr.update(value='Upload & Start Chat',
                          interactive=True), chat_state, None,
                gr.update(value='Restart', interactive=False),
                gr.update(value='English', interactive=True))

    if (language == 'English'):
        chat_state = EN_CONV_VISION.copy()
    else:
        chat_state = ZH_CONV_VISION.copy()
    img_list = []
    gr_img_array = np.asarray(gr_img)
    chat.upload_img(gr_img_array, chat_state, img_list)
    return (gr.update(interactive=False),
            gr.update(placeholder='Type and press Enter', interactive=True),
            gr.update(value='Start Chatting',
                      interactive=False), chat_state, img_list,
            gr.update(value='Restart',
                      interactive=True), gr.update(interactive=False))


def ask(user_message, chatbot, chat_state):
    if (len(user_message) == 0):
        return gr.update(
            value=None,
            placeholder='Input should not be empty!',
            interactive=True), chatbot, chat_state
    chat.ask(user_message, chat_state)
    chatbot = chatbot + [[user_message, None]]
    return '', chatbot, chat_state


def answer(chatbot, chat_state, img_list):
    llm_message = chat.answer(
        conv=chat_state,
        img_list=img_list,
        generation_cfg=model.generation_cfg)
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, img_list


if __name__ == '__main__':
    title = 'MMPretrain MiniGPT-4 Inference Demo'
    with gr.Blocks(analytics_enabled=False, title=title) as demo:
        gr.Markdown(f'# {title}')
        with gr.Row():
            with gr.Column():
                image = gr.Image(type='pil')
                language = gr.Dropdown(['English', 'Chinese'],
                                       label='Language',
                                       info='Select chatbot\'s language',
                                       value='English',
                                       interactive=True)
                upload_button = gr.Button(
                    value='Upload & Start Chat', interactive=True)
                clear = gr.Button(value='Restart', interactive=False)

            with gr.Column():
                chat_state = gr.State()
                img_list = gr.State()
                chatbot = gr.Chatbot(
                    label='MiniGPT-4', min_width=320, height=600)
                text_input = gr.Textbox(
                    label='User',
                    placeholder='Please upload your image first',
                    interactive=False)

        upload_button.click(upload_img, [image, language, chat_state], [
            image, text_input, upload_button, chat_state, img_list, clear,
            language
        ])
        text_input.submit(ask, [text_input, chatbot, chat_state],
                          [text_input, chatbot, chat_state]).then(
                              answer, [chatbot, chat_state, img_list],
                              [chatbot, chat_state, img_list])
        clear.click(reset, [chat_state, img_list], [
            chatbot, image, text_input, upload_button, chat_state, img_list,
            clear, language
        ])

    demo.launch(share=True)
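As a usage note (not part of this diff): the script takes the config and checkpoint as positional arguments, so it can be launched with something like `python minigpt4_gradio_demo.py configs/minigpt4/minigpt-4_baichuan-7b_caption.py /path/to/checkpoint.pth` (the script filename here is illustrative), and `share=True` additionally publishes a temporary public Gradio URL.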