[Feature] Add minigpt4 gradio demo and training script. (#1758)

* Add minigpt4 gradio demo

* update minigpt4 demo

* update minigpt4 demo (inference with float16)

* update minigpt4 and some dependent files

* add minigpt4 dataset for training

* add training script for minigpt4

* restore files deleted by mistake

* fix an error

* remove useless modification

* provide command line arguments for minigpt4 gradio demo and update some comments

* update code

* Update minigpt-4 readme

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
hmtbgc 2023-10-12 10:36:17 +08:00 committed by GitHub
parent 5c71de6b8e
commit c0766519b1
9 changed files with 651 additions and 50 deletions

View File

@@ -35,8 +35,9 @@ For Vicuna model, please refer to [MiniGPT-4 page](https://github.com/Vision-CAI
 ### Pretrained models

 | Model | Params (M) | Flops (G) | Config | Download |
-| :------------------------------ | :--------: | :-------: | :--------------------------------------: | :------------------------------------------------------------------------------------------------------------: |
-| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth) |
+| :------------------------------ | :--------: | :-------: | :----------------------------------------: | :----------------------------------------------------------------------------------------------------------: |
+| `minigpt-4_baichuan-7b_caption` | 8094.77 | N/A | [config](minigpt-4_baichuan-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth) |
+| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth) |

 *Models with * are converted from the [official repo](https://github.com/Vision-CAIR/MiniGPT-4/tree/main). The config files of these models are only for inference. We haven't reproduced the training results.*
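Both checkpoints in the table above plug into mmpretrain's high-level inference API. The snippet below is a minimal sketch of single-image captioning with the new Baichuan-7B variant; the image path is a placeholder, the config/checkpoint locations are taken from the table, a GPU large enough for the roughly 8B-parameter model is assumed, and the Baichuan language model itself is fetched from Hugging Face on first use.

```python
from mmpretrain import ImageCaptionInferencer

# Config and checkpoint URL come from the table above; 'demo.jpg' is a placeholder.
inferencer = ImageCaptionInferencer(
    model='configs/minigpt4/minigpt-4_baichuan-7b_caption.py',
    pretrained='https://download.openmmlab.com/mmclassification/v1/minigpt4/'
    'minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth',
    device='cuda:0')

result = inferencer('demo.jpg')[0]
print(result['pred_caption'])
```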

View File

@@ -19,8 +19,19 @@ Models:
       - Task: Image Caption
         Dataset: COCO
         Metrics: null
-    Weights: https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth
+    Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth
     Config: configs/minigpt4/minigpt-4_vicuna-7b_caption.py
     Converted From:
       Weights: https://github.com/Vision-CAIR/MiniGPT-4/tree/main
       Code: https://github.com/Vision-CAIR/MiniGPT-4/tree/main
+  - Name: minigpt-4_baichuan-7b_caption
+    Metadata:
+      FLOPs: null
+      Parameters: 8094769024
+    In Collection: MiniGPT4
+    Results:
+      - Task: Image Caption
+        Dataset: COCO
+        Metrics: null
+    Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth
+    Config: configs/minigpt4/minigpt-4_baichuan-7b_caption.py

View File

@@ -0,0 +1,190 @@
_base_ = [
    '../_base_/default_runtime.py',
]

data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

# dataset settings
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(224, 224),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='CleanCaption',
        keys='chat_content',
        remove_chars='',
        lowercase=False),
    dict(
        type='PackInputs',
        algorithm_keys=['chat_content', 'lang'],
        meta_keys=['image_id']),
]

train_dataloader = dict(
    batch_size=2,
    num_workers=4,
    dataset=dict(
        type='MiniGPT4Dataset',
        data_root='YOUR_DATA_DIRECTORY',
        ann_file='YOUR_DATA_FILE',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    drop_last=False,
)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(224, 224),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='PackInputs', meta_keys=['image_id']),
]

test_evaluator = dict(
    type='COCOCaption',
    ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)

test_dataloader = dict(
    batch_size=1,
    dataset=dict(
        type='COCOCaption',
        data_root='data/coco',
        ann_file='annotations/coco_karpathy_val.json',
        pipeline=test_pipeline))

# model settings
model = dict(
    type='MiniGPT4',
    vision_encoder=dict(
        type='BEiTViT',
        # eva-g without the final layer
        arch=dict(
            embed_dims=1408,
            num_layers=39,
            num_heads=16,
            feedforward_channels=6144,
        ),
        img_size=224,
        patch_size=14,
        layer_scale_init_value=0.0,
        frozen_stages=39,
        use_abs_pos_emb=True,
        use_rel_pos_bias=False,
        final_norm=False,
        use_shared_rel_pos_bias=False,
        out_type='raw',
        pretrained=  # noqa
        'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_eva-g-p14_20230615-e908c021.pth'  # noqa
    ),
    q_former_model=dict(
        type='Qformer',
        model_style='bert-base-uncased',
        vision_model_width=1408,
        add_cross_attention=True,
        cross_attention_freq=2,
        num_query_token=32,
        pretrained=  # noqa
        'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_qformer_20230615-1dfa889c.pth'  # noqa
    ),
    lang_encoder=dict(
        type='AutoModelForCausalLM',
        name_or_path='baichuan-inc/baichuan-7B',
        trust_remote_code=True),
    tokenizer=dict(
        type='AutoTokenizer',
        name_or_path='baichuan-inc/baichuan-7B',
        trust_remote_code=True),
    task='caption',
    prompt_template=dict([('en', '###Ask: {} ###Answer: '),
                          ('zh', '###问:{} ###答:')]),
    raw_prompts=dict([
        ('en', [('<Img><ImageHere></Img> '
                 'Describe this image in detail.'),
                ('<Img><ImageHere></Img> '
                 'Take a look at this image and describe what you notice.'),
                ('<Img><ImageHere></Img> '
                 'Please provide a detailed description of the picture.'),
                ('<Img><ImageHere></Img> '
                 'Could you describe the contents of this image for me?')]),
        ('zh', [('<Img><ImageHere></Img> '
                 '详细描述这张图片。'),
                ('<Img><ImageHere></Img> '
                 '浏览这张图片并描述你注意到什么。'),
                ('<Img><ImageHere></Img> '
                 '请对这张图片进行详细的描述。'),
                ('<Img><ImageHere></Img> '
                 '你能为我描述这张图片的内容吗?')])
    ]),
    max_txt_len=160,
    end_sym='###')

strategy = dict(
    type='DeepSpeedStrategy',
    fp16=dict(
        enabled=True,
        auto_cast=False,
        fp16_master_weights_and_grads=False,
        loss_scale=0,
        loss_scale_window=1000,
        hysteresis=1,
        min_loss_scale=1,
        initial_scale_power=16,
    ),
    inputs_to_half=[0],
    zero_optimization=dict(
        stage=2,
        allgather_partitions=True,
        allgather_bucket_size=2e8,
        reduce_scatter=True,
        reduce_bucket_size='auto',
        overlap_comm=True,
        contiguous_gradients=True,
    ),
)

# schedule settings
optim_wrapper = dict(
    type='DeepSpeedOptimWrapper',
    optimizer=dict(type='AdamW', lr=1e-3, weight_decay=0.05))

param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-3 / 500,
        by_epoch=False,
        begin=0,
        end=500,
    ),
    dict(
        type='CosineAnnealingLR',
        eta_min=2e-4,
        by_epoch=False,
        begin=500,
    ),
]

train_cfg = dict(by_epoch=True, max_epochs=6)
test_cfg = dict()

runner_type = 'FlexibleRunner'

default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=1,
        by_epoch=True,
        save_last=True,
        max_keep_ckpts=1,
    ))
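Because this config sets `runner_type = 'FlexibleRunner'` together with a `DeepSpeedStrategy`, training follows the usual OpenMMLab flow. The sketch below mirrors what the standard `tools/train.py` does when `runner_type` is present; the work_dir is a placeholder, `deepspeed` must be installed, the dataset placeholders (`YOUR_DATA_DIRECTORY`, `YOUR_DATA_FILE`) must be filled in first, and multi-GPU runs are normally started through a distributed launcher such as `torchrun` rather than a plain single process.

```python
from mmengine.config import Config

from mmpretrain.registry import RUNNERS

# Load the config above (path assumed relative to the mmpretrain repo root).
cfg = Config.fromfile('configs/minigpt4/minigpt-4_baichuan-7b_caption.py')
cfg.work_dir = 'work_dirs/minigpt-4_baichuan-7b_caption'  # placeholder output dir

# Since `runner_type = 'FlexibleRunner'` is set, build the runner from the
# registry (as tools/train.py does) instead of the default Runner.from_cfg.
runner = RUNNERS.build(cfg)
runner.train()
```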

View File

@@ -55,13 +55,25 @@ model = dict(
         type='AutoModelForCausalLM', name_or_path='YOUR_PATH_TO_VICUNA'),
     tokenizer=dict(type='LlamaTokenizer', name_or_path='YOUR_PATH_TO_VICUNA'),
     task='caption',
-    prompt_template='###Human: {} ###Assistant: ',
-    raw_prompts=[
-        '<Img><ImageHere></Img> Describe this image in detail.',
-        '<Img><ImageHere></Img> Take a look at this image and describe what you notice.',  # noqa
-        '<Img><ImageHere></Img> Please provide a detailed description of the picture.',  # noqa
-        '<Img><ImageHere></Img> Could you describe the contents of this image for me?',  # noqa
-    ],
+    prompt_template=dict([('en', '###Ask: {} ###Answer: '),
+                          ('zh', '###问:{} ###答:')]),
+    raw_prompts=dict([
+        ('en', [('<Img><ImageHere></Img> '
+                 'Describe this image in detail.'),
+                ('<Img><ImageHere></Img> '
+                 'Take a look at this image and describe what you notice.'),
+                ('<Img><ImageHere></Img> '
+                 'Please provide a detailed description of the picture.'),
+                ('<Img><ImageHere></Img> '
+                 'Could you describe the contents of this image for me?')]),
+        ('zh', [('<Img><ImageHere></Img> '
+                 '详细描述这张图片。'),
+                ('<Img><ImageHere></Img> '
+                 '浏览这张图片并描述你注意到什么。'),
+                ('<Img><ImageHere></Img> '
+                 '请对这张图片进行详细的描述。'),
+                ('<Img><ImageHere></Img> '
+                 '你能为我描述这张图片的内容吗?')])
+    ]),
     max_txt_len=160,
     end_sym='###')

View File

@ -43,6 +43,7 @@ if WITH_MULTIMODAL:
from .gqa_dataset import GQA from .gqa_dataset import GQA
from .iconqa import IconQA from .iconqa import IconQA
from .infographic_vqa import InfographicVQA from .infographic_vqa import InfographicVQA
from .minigpt4_dataset import MiniGPT4Dataset
from .nocaps import NoCaps from .nocaps import NoCaps
from .ocr_vqa import OCRVQA from .ocr_vqa import OCRVQA
from .refcoco import RefCOCO from .refcoco import RefCOCO
@ -56,5 +57,6 @@ if WITH_MULTIMODAL:
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption', 'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval', 'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA', 'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA' 'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA',
'MiniGPT4Dataset'
]) ])

View File

@@ -0,0 +1,79 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class MiniGPT4Dataset(BaseDataset):
    """Dataset for training MiniGPT4.

    MiniGPT4 dataset directory:

        minigpt4_dataset
            image
                id0.jpg
                id1.jpg
                id2.jpg
                ...
            conversation_data.json

    The structure of conversation_data.json:

        [
            // English data
            {
                "id": str(id0),
                "conversation": "###Ask: <Img><ImageHere></Img> [Ask content]
                                 ###Answer: [Answer content]"
            },

            // Chinese data
            {
                "id": str(id1),
                "conversation": "###问:<Img><ImageHere></Img> [Ask content]
                                 ###答:[Answer content]"
            },

            ...
        ]

    Args:
        data_root (str): The root directory for ``ann_file`` and ``image``.
        ann_file (str): Conversation file path.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def load_data_list(self) -> List[dict]:
        file_backend = get_file_backend(self.data_root)
        conversation_path = file_backend.join_path(self.data_root,
                                                   self.ann_file)
        conversation = mmengine.load(conversation_path)
        img_ids = {}
        n = 0
        for conv in conversation:
            img_id = conv['id']
            if img_id not in img_ids.keys():
                img_ids[img_id] = n
                n += 1

        img_root = file_backend.join_path(self.data_root, 'image')
        data_list = []
        for conv in conversation:
            img_file = '{}.jpg'.format(conv['id'])
            chat_content = conv['conversation']
            lang = 'en' if chat_content.startswith('###Ask: ') else 'zh'
            data_info = {
                'image_id': img_ids[conv['id']],
                'img_path': file_backend.join_path(img_root, img_file),
                'chat_content': chat_content,
                'lang': lang,
            }
            data_list.append(data_info)

        return data_list
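To make the expected annotation format concrete, here is a small, self-contained sketch that writes a toy `conversation_data.json` in the layout described by the docstring and loads it back through `MiniGPT4Dataset`. The paths and captions are made up; an absolute `ann_file` is used to sidestep any ambiguity about path joining, and the pipeline is left empty so no images are actually read.

```python
import json
from pathlib import Path

from mmpretrain.datasets import MiniGPT4Dataset

# Build a toy dataset directory following the structure in the docstring above.
root = Path('minigpt4_dataset').resolve()
(root / 'image').mkdir(parents=True, exist_ok=True)

annotations = [
    {
        'id': '0',
        'conversation': '###Ask: <Img><ImageHere></Img> Describe this image '
                        'in detail. ###Answer: A cat sleeping on a sofa.'
    },
    {
        'id': '1',
        'conversation': '###问:<Img><ImageHere></Img> 详细描述这张图片。'
                        '###答:一只狗在草地上奔跑。'
    },
]
ann_file = root / 'conversation_data.json'
ann_file.write_text(json.dumps(annotations, ensure_ascii=False), encoding='utf-8')

# An empty pipeline keeps the raw data infos; the real training pipeline lives
# in the config above.
dataset = MiniGPT4Dataset(
    data_root=str(root), ann_file=str(ann_file), pipeline=[])
print(len(dataset))                # 2
print(dataset[0]['lang'])          # 'en'
print(dataset[1]['chat_content'])  # the Chinese conversation string
```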

View File

@@ -31,12 +31,12 @@ class MiniGPT4(BaseModel):
             True.
         num_query_token (int): Number of query tokens of Qformer. Defaults to
             32.
-        prompt_template (str): Prompt template of the model. Defaults to
-            '###Human: {} ###Assistant: '.
-        raw_prompts (list): Prompts for training. Defaults to None.
+        prompt_template (dict): Multi-language prompt template of the model. Defaults to dict([ ('en', '###Ask: {} ###Answer: '),
+            ('zh', '###问:{} ###答:')])
+        raw_prompts (dict): Prompts for training. Defaults to dict().
         max_txt_len (int): Max token length while doing tokenization. Defaults
             to 32.
-        end_sym (str): Ended symbol of the sequence. Defaults to '\\n'.
+        end_sym (str): Ended symbol of the sequence. Defaults to '###'.
         generation_cfg (dict): The config of text generation. Defaults to
             dict().
         data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
@@ -54,10 +54,12 @@ class MiniGPT4(BaseModel):
                  freeze_vit: bool = True,
                  freeze_q_former: bool = True,
                  num_query_token: int = 32,
-                 prompt_template: str = '###Human: {} ###Assistant: ',
-                 raw_prompts: Optional[list] = None,
+                 prompt_template: dict = dict([('en',
+                                                '###Ask: {} ###Answer: '),
+                                               ('zh', '###问:{} ###答:')]),
+                 raw_prompts: dict = dict(),
                  max_txt_len: int = 32,
-                 end_sym: str = '\n',
+                 end_sym: str = '###',
                  generation_cfg: dict = dict(),
                  data_preprocessor: Optional[dict] = None,
                  init_cfg: Optional[dict] = None):
@@ -135,16 +137,23 @@ class MiniGPT4(BaseModel):
         self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1]

         # set prompts
-        if raw_prompts is not None:
-            filted_prompts = [
-                raw_prompt for raw_prompt in raw_prompts
+        self.en_prompt_list, self.zh_prompt_list = [], []
+        if raw_prompts.get('en') is not None:
+            en_filted_prompts = [
+                raw_prompt for raw_prompt in raw_prompts['en']
                 if '<ImageHere>' in raw_prompt
             ]
-            self.prompt_list = [
-                prompt_template.format(p) for p in filted_prompts
+            self.en_prompt_list = [
+                prompt_template['en'].format(p) for p in en_filted_prompts
+            ]
+        if raw_prompts.get('zh') is not None:
+            zh_filted_prompts = [
+                raw_prompt for raw_prompt in raw_prompts['zh']
+                if '<ImageHere>' in raw_prompt
+            ]
+            self.zh_prompt_list = [
+                prompt_template['zh'].format(p) for p in zh_filted_prompts
             ]
-        else:
-            self.prompt_list = []

         # update generation configs
@@ -153,7 +162,7 @@ class MiniGPT4(BaseModel):
             do_sample=True,
             min_length=1,
             top_p=0.9,
-            repetition_penalty=1.0,
+            repetition_penalty=1.1,
             length_penalty=1.0,
             temperature=1.0)
         self.generation_cfg.update(**generation_cfg)
@@ -161,6 +170,10 @@ class MiniGPT4(BaseModel):
         if hasattr(self, 'register_load_state_dict_post_hook'):
             self.register_load_state_dict_post_hook(self._load_llama_proj_hook)

+    def half(self):
+        self.llama_model = self.llama_model.half()
+        return self
+
     def encode_img(self,
                    images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """The function to encode the images."""
@@ -184,33 +197,39 @@ class MiniGPT4(BaseModel):
         return inputs_llama, atts_llama

     def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor,
-                    prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
+                    prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
         """The function to wrap the image and prompt.

-        Currently, the function only supports applying one prompt to all input
-        images in the one batch.
+        Make sure that len(prompt) == img_embeds.shape[0].

         Args:
             img_embeds (torch.Tensor): The embedding of the input images.
             atts_img (torch.Tensor): Attention map of the image embeddings.
-            prompt (str): The prompt of the batch data.
+            prompt (List[str]): The prompt of the batch data.

         Returns:
             Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map.
         """
-        if prompt:
-            batch_size = img_embeds.shape[0]
-            p_before, p_after = prompt.split('<ImageHere>')
+        if len(prompt) > 0:
+            p_before_list, p_after_list = [], []
+            for pro in prompt:
+                p_before, p_after = pro.split('<ImageHere>')
+                p_before_list.append(p_before)
+                p_after_list.append(p_after)
             p_before_tokens = self.llama_tokenizer(
-                p_before, return_tensors='pt',
+                p_before_list,
+                return_tensors='pt',
+                padding='longest',
                 add_special_tokens=False).to(img_embeds.device)
             p_after_tokens = self.llama_tokenizer(
-                p_after, return_tensors='pt',
+                p_after_list,
+                return_tensors='pt',
+                padding='longest',
                 add_special_tokens=False).to(img_embeds.device)
             p_before_embeds = self.llama_model.model.embed_tokens(
-                p_before_tokens.input_ids).expand(batch_size, -1, -1)
+                p_before_tokens.input_ids)
             p_after_embeds = self.llama_model.model.embed_tokens(
-                p_after_tokens.input_ids).expand(batch_size, -1, -1)
+                p_after_tokens.input_ids)
             wrapped_img_embeds = torch.cat(
                 [p_before_embeds, img_embeds, p_after_embeds], dim=1)
             wrapped_atts_img = atts_img[:, :1].expand(
@@ -234,17 +253,22 @@ class MiniGPT4(BaseModel):
         """
         img_embeds, atts_img = self.encode_img(images)

-        if self.task == 'caption' and self.prompt_list:
-            prompt = random.choice(self.prompt_list)
-            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
-                                                    prompt)
-
         self.llama_tokenizer.padding_side = 'right'

-        text = [t + self.end_sym for t in data_samples['text_input']]
+        prompts, texts = [], []
+        for t in data_samples:
+            chat_content = t.chat_content
+            split_mark = '###Answer: ' if t.lang == 'en' else '###答:'
+            prompt, text = chat_content.split(split_mark)
+            prompt += split_mark
+            text += self.end_sym
+            prompts.append(prompt)
+            texts.append(text)
+        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)

         to_regress_tokens = self.llama_tokenizer(
-            text,
+            texts,
             return_tensors='pt',
             padding='longest',
             truncation=True,
@@ -295,10 +319,12 @@ class MiniGPT4(BaseModel):
         with torch.no_grad():
             img_embeds, atts_img = self.encode_img(images)

-        if self.task == 'caption' and self.prompt_list:
-            prompt = random.choice(self.prompt_list)
-            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
-                                                    prompt)
+        prompts = [
+            random.choice(self.zh_prompt_list) if hasattr(t, 'lang')
+            and t.lang == 'zh' else random.choice(self.en_prompt_list)
+            for t in data_samples
+        ]
+        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)

         batch_size = img_embeds.shape[0]
         bos = torch.ones(
@@ -336,7 +362,6 @@ class MiniGPT4(BaseModel):
         for output, data_sample in zip(outputs, data_samples):
             if self.task == 'caption':
                 output = output.split('###')[0]
-                output = output.split('Assistant:')[-1].strip()
                 data_sample.pred_caption = output
             else:
                 # raw output
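To illustrate the new loss-side text handling, the sketch below reproduces the split logic from the `loss()` hunk above outside the model: each `chat_content` string produced by `MiniGPT4Dataset` is divided into the prompt half (everything up to and including the answer marker, which is wrapped around the image embeddings) and the regression target. The sample strings are made up.

```python
end_sym = '###'

samples = [
    # (chat_content, lang) pairs in the same format MiniGPT4Dataset produces.
    ('###Ask: <Img><ImageHere></Img> Describe this image in detail. '
     '###Answer: A cat sleeping on a sofa.', 'en'),
    ('###问:<Img><ImageHere></Img> 详细描述这张图片。###答:一只狗在草地上奔跑。', 'zh'),
]

prompts, texts = [], []
for chat_content, lang in samples:
    split_mark = '###Answer: ' if lang == 'en' else '###答:'
    prompt, text = chat_content.split(split_mark)
    prompt += split_mark  # fed through prompt_wrap() around the image embeds
    text += end_sym       # the part the language model is trained to regress
    prompts.append(prompt)
    texts.append(text)

print(prompts[0])  # '###Ask: <Img><ImageHere></Img> Describe this image in detail. ###Answer: '
print(texts[0])    # 'A cat sleeping on a sofa.###'
```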

View File

@@ -0,0 +1,137 @@
# Modified from
# https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py
import dataclasses
from typing import List

import torch


@dataclasses.dataclass
class Conversation:
    system: str
    roles: List[str]
    messages: List[List[str]]
    sep: str = '###'

    def get_prompt(self):
        ret = self.system + self.sep
        for role, message in self.messages:
            if message:
                ret += role + ': ' + message + self.sep
            else:
                ret += role + ':'
        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def copy(self):
        return Conversation(
            system=self.system,
            roles=[role for role in self.roles],
            messages=[[y for y in x] for x in self.messages],
            sep=self.sep,
        )

    def dict(self):
        return {
            'system': self.system,
            'roles': self.roles,
            'messages': self.messages,
            'offset': self.offset,
            'sep': self.sep,
        }


EN_CONV_VISION = Conversation(
    system='Give the following image. '
    'You will be able to see the image once I provide it to you. '
    'Please answer my questions in detail.',
    roles=['Ask', 'Answer'],
    messages=[],
    sep='###',
)

ZH_CONV_VISION = Conversation(
    system='给定一张图片,请仔细观察这张图片,并回答我的问题。',
    roles=['', ''],
    messages=[],
    sep='###',
)


class Chat:

    def __init__(self, inferencer, device, is_half=False):
        self.device = device
        self.inferencer = inferencer
        self.model = inferencer.model
        self.is_half = is_half
        if is_half:
            self.model = self.model.half()
        self.model = self.model.to(device)
        self.max_length = 2000

    def upload_img(self, image, conv, img_list):
        img = next(self.inferencer.preprocess([image]))
        img = self.model.data_preprocessor(img, False)['images']
        img = img.to(self.device)
        image_emb, _ = self.model.encode_img(img)
        img_list.append(image_emb)
        conv.append_message(conv.roles[0], '<Img><ImageHere></Img>')

    def get_context_emb(self, conv, img_list):
        prompt = conv.get_prompt()
        prompt_segs = prompt.split('<ImageHere>')
        seg_tokens = [
            self.model.llama_tokenizer(
                seg, return_tensors='pt',
                add_special_tokens=(i == 0)).to(self.device).input_ids
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [
            self.model.llama_model.model.embed_tokens(seg_token)
            for seg_token in seg_tokens
        ]
        mixed_embs = [
            emb for pair in zip(seg_embs[:-1], img_list) for emb in pair
        ] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs

    def ask(self, text, conv):
        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[
                0] and conv.messages[-1][1][-6:] == '</Img>':
            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        else:
            conv.append_message(conv.roles[0], text)

    def answer(self, conv, img_list, generation_cfg):
        conv.append_message(conv.roles[1], None)
        embs = self.get_context_emb(conv, img_list)
        cur_max_len = generation_cfg['max_new_tokens'] + embs.shape[1]
        if cur_max_len > self.max_length:
            print('Warning: The number of tokens in current conversation '
                  'exceeds the max length. '
                  'The model will not see the contexts outside the range.')
        begin_idx = max(0, cur_max_len - self.max_length)
        embs = embs[:, begin_idx:]
        if self.is_half:
            embs = embs.half()
        outputs = self.model.llama_model.generate(
            inputs_embeds=embs,
            eos_token_id=self.model.end_token_id,
            **generation_cfg)
        output_token = outputs[0]
        if output_token[0] == 0:
            output_token = output_token[1:]
        elif output_token[0] == 1:
            output_token = output_token[1:]
        output_text = self.model.llama_tokenizer.decode(
            output_token,
            add_special_tokens=False,
            skip_special_tokens=True)
        output_text = output_text.split('###')[0]
        conv.messages[-1][1] = output_text
        return output_text
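For reference, here is a minimal non-gradio sketch of the same `Chat` flow, mirroring how the demo script below wires things together. The config, checkpoint and image paths are placeholders, the script is assumed to run from the demo directory so `conversation` is importable, and a CUDA device is assumed for the half-precision path.

```python
import numpy as np
import torch
from PIL import Image

from conversation import EN_CONV_VISION, Chat  # this module
from mmpretrain import ImageCaptionInferencer

# Placeholder paths for the MiniGPT-4 config and checkpoint.
inferencer = ImageCaptionInferencer(
    model='PATH_TO_MINIGPT4_CONFIG.py', pretrained='PATH_TO_CHECKPOINT.pth')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
chat = Chat(inferencer, device=device, is_half=(device.type != 'cpu'))

conv = EN_CONV_VISION.copy()
img_list = []

# The demo passes images as RGB numpy arrays, so do the same here.
image = np.asarray(Image.open('PATH_TO_IMAGE.jpg').convert('RGB'))
chat.upload_img(image, conv, img_list)

chat.ask('Describe this image in detail.', conv)
reply = chat.answer(conv, img_list, inferencer.model.generation_cfg)
print(reply)
```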

View File

@@ -0,0 +1,144 @@
import argparse

import gradio as gr
import numpy as np
import torch
from conversation import EN_CONV_VISION, ZH_CONV_VISION, Chat

from mmpretrain import ImageCaptionInferencer

parser = argparse.ArgumentParser(description='MiniGPT4 demo')
parser.add_argument(
    'cfg', type=str, help='config file for minigpt4 (absolute path)')
parser.add_argument(
    'ckpt', type=str, help='pretrained file for minigpt4 (absolute path)')
args = parser.parse_args()

if torch.cuda.is_available():
    devices = [
        torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
    ]
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    devices = [torch.device('mps')]
else:
    devices = [torch.device('cpu')]


def get_free_device():
    if hasattr(torch.cuda, 'mem_get_info'):
        free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices]
        select = max(zip(free, range(len(free))))[1]
    else:
        import random
        select = random.randint(0, len(devices) - 1)
    return devices[select]


device = get_free_device()
inferencer = ImageCaptionInferencer(model=args.cfg, pretrained=args.ckpt)
model = inferencer.model
chat = Chat(inferencer, device=device, is_half=(device.type != 'cpu'))


def reset(chat_state, img_list):
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    return (None, gr.update(value=None, interactive=True),
            gr.update(
                value=None,
                placeholder='Please upload your image first',
                interactive=False),
            gr.update(value='Upload & Start Chat',
                      interactive=True), chat_state, img_list,
            gr.update(value='Restart', interactive=False),
            gr.update(value='English', interactive=True))


def upload_img(gr_img, language, chat_state):
    if gr_img is None:
        return (None,
                gr.update(
                    placeholder='Please upload your image first',
                    interactive=False),
                gr.update(value='Upload & Start Chat',
                          interactive=True), chat_state, None,
                gr.update(value='Restart', interactive=False),
                gr.update(value='English', interactive=True))
    if (language == 'English'):
        chat_state = EN_CONV_VISION.copy()
    else:
        chat_state = ZH_CONV_VISION.copy()
    img_list = []
    gr_img_array = np.asarray(gr_img)
    chat.upload_img(gr_img_array, chat_state, img_list)
    return (gr.update(interactive=False),
            gr.update(placeholder='Type and press Enter', interactive=True),
            gr.update(value='Start Chatting',
                      interactive=False), chat_state, img_list,
            gr.update(value='Restart',
                      interactive=True), gr.update(interactive=False))


def ask(user_message, chatbot, chat_state):
    if (len(user_message) == 0):
        return gr.update(
            value=None,
            placeholder='Input should not be empty!',
            interactive=True), chatbot, chat_state
    chat.ask(user_message, chat_state)
    chatbot = chatbot + [[user_message, None]]
    return '', chatbot, chat_state


def answer(chatbot, chat_state, img_list):
    llm_message = chat.answer(
        conv=chat_state,
        img_list=img_list,
        generation_cfg=model.generation_cfg)
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, img_list


if __name__ == '__main__':
    title = 'MMPretrain MiniGPT-4 Inference Demo'
    with gr.Blocks(analytics_enabled=False, title=title) as demo:
        gr.Markdown(f'# {title}')
        with gr.Row():
            with gr.Column():
                image = gr.Image(type='pil')
                language = gr.Dropdown(['English', 'Chinese'],
                                       label='Language',
                                       info='Select chatbot\'s language',
                                       value='English',
                                       interactive=True)
                upload_button = gr.Button(
                    value='Upload & Start Chat', interactive=True)
                clear = gr.Button(value='Restart', interactive=False)
            with gr.Column():
                chat_state = gr.State()
                img_list = gr.State()
                chatbot = gr.Chatbot(
                    label='MiniGPT-4', min_width=320, height=600)
                text_input = gr.Textbox(
                    label='User',
                    placeholder='Please upload your image first',
                    interactive=False)
        upload_button.click(upload_img, [image, language, chat_state], [
            image, text_input, upload_button, chat_state, img_list, clear,
            language
        ])
        text_input.submit(ask, [text_input, chatbot, chat_state],
                          [text_input, chatbot, chat_state]).then(
                              answer, [chatbot, chat_state, img_list],
                              [chatbot, chat_state, img_list])
        clear.click(reset, [chat_state, img_list], [
            chatbot, image, text_input, upload_button, chat_state, img_list,
            clear, language
        ])
    demo.launch(share=True)