[Feature] Add minigpt4 gradio demo and training script. (#1758)

* Add minigpt4 gradio demo

* update minigpt4 demo

* update minigpt4 demo (inference with float16)

* update minigpt4 and some dependent files

* add minigpt4 dataset for training

* add training script for minigpt4

* restore files deleted by mistake

* fix an error

* remove useless modification

* provide command line arguments for minigpt4 gradio demo and update some comments

* update code

* Update minigpt-4 readme

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
pull/1853/head
hmtbgc 2023-10-12 10:36:17 +08:00 committed by GitHub
parent 5c71de6b8e
commit c0766519b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 651 additions and 50 deletions

View File

@ -34,9 +34,10 @@ For Vicuna model, please refer to [MiniGPT-4 page](https://github.com/Vision-CAI
### Pretrained models
| Model | Params (M) | Flops (G) | Config | Download |
| :------------------------------ | :--------: | :-------: | :--------------------------------------: | :------------------------------------------------------------------------------------------------------------: |
| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth) |
| Model | Params (M) | Flops (G) | Config | Download |
| :------------------------------ | :--------: | :-------: | :----------------------------------------: | :----------------------------------------------------------------------------------------------------------: |
| `minigpt-4_baichuan-7b_caption` | 8094.77 | N/A | [config](minigpt-4_baichuan-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth) |
| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth) |
*Models with * are converted from the [official repo](https://github.com/Vision-CAIR/MiniGPT-4/tree/main). The config files of these models are only for inference. We haven't reproduce the training results.*

View File

@ -19,8 +19,19 @@ Models:
- Task: Image Caption
Dataset: COCO
Metrics: null
Weights: https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth
Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth
Config: configs/minigpt4/minigpt-4_vicuna-7b_caption.py
Converted From:
Weights: https://github.com/Vision-CAIR/MiniGPT-4/tree/main
Code: https://github.com/Vision-CAIR/MiniGPT-4/tree/main
- Name: minigpt-4_baichuan-7b_caption
Metadata:
FLOPs: null
Parameters: 8094769024
In Collection: MiniGPT4
Results:
- Task: Image Caption
Dataset: COCO
Metrics: null
Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth
Config: configs/minigpt4/minigpt-4_baichuan-7b_caption.py

View File

@ -0,0 +1,190 @@
_base_ = [
'../_base_/default_runtime.py',
]
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
# dataset settings
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(224, 224),
interpolation='bicubic',
backend='pillow'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='CleanCaption',
keys='chat_content',
remove_chars='',
lowercase=False),
dict(
type='PackInputs',
algorithm_keys=['chat_content', 'lang'],
meta_keys=['image_id']),
]
train_dataloader = dict(
batch_size=2,
num_workers=4,
dataset=dict(
type='MiniGPT4Dataset',
data_root='YOUR_DATA_DIRECTORY',
ann_file='YOUR_DATA_FILE',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
drop_last=False,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(224, 224),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id']),
]
test_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)
test_dataloader = dict(
batch_size=1,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_val.json',
pipeline=test_pipeline))
# model settings
model = dict(
type='MiniGPT4',
vision_encoder=dict(
type='BEiTViT',
# eva-g without the final layer
arch=dict(
embed_dims=1408,
num_layers=39,
num_heads=16,
feedforward_channels=6144,
),
img_size=224,
patch_size=14,
layer_scale_init_value=0.0,
frozen_stages=39,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
final_norm=False,
use_shared_rel_pos_bias=False,
out_type='raw',
pretrained= # noqa
'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_eva-g-p14_20230615-e908c021.pth' # noqa
),
q_former_model=dict(
type='Qformer',
model_style='bert-base-uncased',
vision_model_width=1408,
add_cross_attention=True,
cross_attention_freq=2,
num_query_token=32,
pretrained= # noqa
'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_qformer_20230615-1dfa889c.pth' # noqa
),
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='baichuan-inc/baichuan-7B',
trust_remote_code=True),
tokenizer=dict(
type='AutoTokenizer',
name_or_path='baichuan-inc/baichuan-7B',
trust_remote_code=True),
task='caption',
prompt_template=dict([('en', '###Ask: {} ###Answer: '),
('zh', '###问:{} ###答:')]),
raw_prompts=dict([
('en', [('<Img><ImageHere></Img> '
'Describe this image in detail.'),
('<Img><ImageHere></Img> '
'Take a look at this image and describe what you notice.'),
('<Img><ImageHere></Img> '
'Please provide a detailed description of the picture.'),
('<Img><ImageHere></Img> '
'Could you describe the contents of this image for me?')]),
('zh', [('<Img><ImageHere></Img> '
'详细描述这张图片。'), ('<Img><ImageHere></Img> '
'浏览这张图片并描述你注意到什么。'),
('<Img><ImageHere></Img> '
'请对这张图片进行详细的描述。'),
('<Img><ImageHere></Img> '
'你能为我描述这张图片的内容吗?')])
]),
max_txt_len=160,
end_sym='###')
strategy = dict(
type='DeepSpeedStrategy',
fp16=dict(
enabled=True,
auto_cast=False,
fp16_master_weights_and_grads=False,
loss_scale=0,
loss_scale_window=1000,
hysteresis=1,
min_loss_scale=1,
initial_scale_power=16,
),
inputs_to_half=[0],
zero_optimization=dict(
stage=2,
allgather_partitions=True,
allgather_bucket_size=2e8,
reduce_scatter=True,
reduce_bucket_size='auto',
overlap_comm=True,
contiguous_gradients=True,
),
)
# schedule settings
optim_wrapper = dict(
type='DeepSpeedOptimWrapper',
optimizer=dict(type='AdamW', lr=1e-3, weight_decay=0.05))
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-3 / 500,
by_epoch=False,
begin=0,
end=500,
),
dict(
type='CosineAnnealingLR',
eta_min=2e-4,
by_epoch=False,
begin=500,
),
]
train_cfg = dict(by_epoch=True, max_epochs=6)
test_cfg = dict()
runner_type = 'FlexibleRunner'
default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
interval=1,
by_epoch=True,
save_last=True,
max_keep_ckpts=1,
))

View File

@ -55,13 +55,25 @@ model = dict(
type='AutoModelForCausalLM', name_or_path='YOUR_PATH_TO_VICUNA'),
tokenizer=dict(type='LlamaTokenizer', name_or_path='YOUR_PATH_TO_VICUNA'),
task='caption',
prompt_template='###Human: {} ###Assistant: ',
raw_prompts=[
'<Img><ImageHere></Img> Describe this image in detail.',
'<Img><ImageHere></Img> Take a look at this image and describe what you notice.', # noqa
'<Img><ImageHere></Img> Please provide a detailed description of the picture.', # noqa
'<Img><ImageHere></Img> Could you describe the contents of this image for me?', # noqa
],
prompt_template=dict([('en', '###Ask: {} ###Answer: '),
('zh', '###问:{} ###答:')]),
raw_prompts=dict([
('en', [('<Img><ImageHere></Img> '
'Describe this image in detail.'),
('<Img><ImageHere></Img> '
'Take a look at this image and describe what you notice.'),
('<Img><ImageHere></Img> '
'Please provide a detailed description of the picture.'),
('<Img><ImageHere></Img> '
'Could you describe the contents of this image for me?')]),
('zh', [('<Img><ImageHere></Img> '
'详细描述这张图片。'), ('<Img><ImageHere></Img> '
'浏览这张图片并描述你注意到什么。'),
('<Img><ImageHere></Img> '
'请对这张图片进行详细的描述。'),
('<Img><ImageHere></Img> '
'你能为我描述这张图片的内容吗?')])
]),
max_txt_len=160,
end_sym='###')

View File

@ -43,6 +43,7 @@ if WITH_MULTIMODAL:
from .gqa_dataset import GQA
from .iconqa import IconQA
from .infographic_vqa import InfographicVQA
from .minigpt4_dataset import MiniGPT4Dataset
from .nocaps import NoCaps
from .ocr_vqa import OCRVQA
from .refcoco import RefCOCO
@ -56,5 +57,6 @@ if WITH_MULTIMODAL:
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA'
'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA',
'MiniGPT4Dataset'
])

View File

@ -0,0 +1,79 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class MiniGPT4Dataset(BaseDataset):
"""Dataset for training MiniGPT4.
MiniGPT4 dataset directory:
minigpt4_dataset
image
id0.jpg
id1.jpg
id2.jpg
...
conversation_data.json
The structure of conversation_data.json:
[
// English data
{
"id": str(id0),
"conversation": "###Ask: <Img><ImageHere></Img> [Ask content]
###Answer: [Answer content]"
},
// Chinese data
{
"id": str(id1),
"conversation": "###问:<Img><ImageHere></Img> [Ask content]
###答:[Answer content]"
},
...
]
Args:
data_root (str): The root directory for ``ann_file`` and ``image``.
ann_file (str): Conversation file path.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""
def load_data_list(self) -> List[dict]:
file_backend = get_file_backend(self.data_root)
conversation_path = file_backend.join_path(self.data_root,
self.ann_file)
conversation = mmengine.load(conversation_path)
img_ids = {}
n = 0
for conv in conversation:
img_id = conv['id']
if img_id not in img_ids.keys():
img_ids[img_id] = n
n += 1
img_root = file_backend.join_path(self.data_root, 'image')
data_list = []
for conv in conversation:
img_file = '{}.jpg'.format(conv['id'])
chat_content = conv['conversation']
lang = 'en' if chat_content.startswith('###Ask: ') else 'zh'
data_info = {
'image_id': img_ids[conv['id']],
'img_path': file_backend.join_path(img_root, img_file),
'chat_content': chat_content,
'lang': lang,
}
data_list.append(data_info)
return data_list

View File

@ -31,12 +31,12 @@ class MiniGPT4(BaseModel):
True.
num_query_token (int): Number of query tokens of Qformer. Defaults to
32.
prompt_template (str): Prompt template of the model. Defaults to
'###Human: {} ###Assistant: '.
raw_prompts (list): Prompts for training. Defaults to None.
prompt_template (dict): Multi-language prompt template of the model. Defaults to dict([ ('en', '###Ask: {} ###Answer: '),
('zh', '###问:{} ###答:')])
raw_prompts (dict): Prompts for training. Defaults to dict().
max_txt_len (int): Max token length while doing tokenization. Defaults
to 32.
end_sym (str): Ended symbol of the sequence. Defaults to '\\n'.
end_sym (str): Ended symbol of the sequence. Defaults to '###'.
generation_cfg (dict): The config of text generation. Defaults to
dict().
data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
@ -54,10 +54,12 @@ class MiniGPT4(BaseModel):
freeze_vit: bool = True,
freeze_q_former: bool = True,
num_query_token: int = 32,
prompt_template: str = '###Human: {} ###Assistant: ',
raw_prompts: Optional[list] = None,
prompt_template: dict = dict([('en',
'###Ask: {} ###Answer: '),
('zh', '###问:{} ###答:')]),
raw_prompts: dict = dict(),
max_txt_len: int = 32,
end_sym: str = '\n',
end_sym: str = '###',
generation_cfg: dict = dict(),
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[dict] = None):
@ -135,16 +137,23 @@ class MiniGPT4(BaseModel):
self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1]
# set prompts
if raw_prompts is not None:
filted_prompts = [
raw_prompt for raw_prompt in raw_prompts
self.en_prompt_list, self.zh_prompt_list = [], []
if raw_prompts.get('en') is not None:
en_filted_prompts = [
raw_prompt for raw_prompt in raw_prompts['en']
if '<ImageHere>' in raw_prompt
]
self.prompt_list = [
prompt_template.format(p) for p in filted_prompts
self.en_prompt_list = [
prompt_template['en'].format(p) for p in en_filted_prompts
]
if raw_prompts.get('zh') is not None:
zh_filted_prompts = [
raw_prompt for raw_prompt in raw_prompts['zh']
if '<ImageHere>' in raw_prompt
]
self.zh_prompt_list = [
prompt_template['zh'].format(p) for p in zh_filted_prompts
]
else:
self.prompt_list = []
# update generation configs
self.generation_cfg = dict(
@ -153,7 +162,7 @@ class MiniGPT4(BaseModel):
do_sample=True,
min_length=1,
top_p=0.9,
repetition_penalty=1.0,
repetition_penalty=1.1,
length_penalty=1.0,
temperature=1.0)
self.generation_cfg.update(**generation_cfg)
@ -161,6 +170,10 @@ class MiniGPT4(BaseModel):
if hasattr(self, 'register_load_state_dict_post_hook'):
self.register_load_state_dict_post_hook(self._load_llama_proj_hook)
def half(self):
self.llama_model = self.llama_model.half()
return self
def encode_img(self,
images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""The function to encode the images."""
@ -184,33 +197,39 @@ class MiniGPT4(BaseModel):
return inputs_llama, atts_llama
def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor,
prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
"""The function to wrap the image and prompt.
Currently, the function only supports applying one prompt to all input
images in the one batch.
Make sure that len(prompt) == img_embeds.shape[0].
Args:
img_embeds (torch.Tensor): The embedding of the input images.
atts_img (torch.Tensor): Attention map of the image embeddings.
prompt (str): The prompt of the batch data.
prompt (List[str]): The prompt of the batch data.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map.
"""
if prompt:
batch_size = img_embeds.shape[0]
p_before, p_after = prompt.split('<ImageHere>')
if len(prompt) > 0:
p_before_list, p_after_list = [], []
for pro in prompt:
p_before, p_after = pro.split('<ImageHere>')
p_before_list.append(p_before)
p_after_list.append(p_after)
p_before_tokens = self.llama_tokenizer(
p_before, return_tensors='pt',
p_before_list,
return_tensors='pt',
padding='longest',
add_special_tokens=False).to(img_embeds.device)
p_after_tokens = self.llama_tokenizer(
p_after, return_tensors='pt',
p_after_list,
return_tensors='pt',
padding='longest',
add_special_tokens=False).to(img_embeds.device)
p_before_embeds = self.llama_model.model.embed_tokens(
p_before_tokens.input_ids).expand(batch_size, -1, -1)
p_before_tokens.input_ids)
p_after_embeds = self.llama_model.model.embed_tokens(
p_after_tokens.input_ids).expand(batch_size, -1, -1)
p_after_tokens.input_ids)
wrapped_img_embeds = torch.cat(
[p_before_embeds, img_embeds, p_after_embeds], dim=1)
wrapped_atts_img = atts_img[:, :1].expand(
@ -234,17 +253,22 @@ class MiniGPT4(BaseModel):
"""
img_embeds, atts_img = self.encode_img(images)
if self.task == 'caption' and self.prompt_list:
prompt = random.choice(self.prompt_list)
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
prompt)
self.llama_tokenizer.padding_side = 'right'
text = [t + self.end_sym for t in data_samples['text_input']]
prompts, texts = [], []
for t in data_samples:
chat_content = t.chat_content
split_mark = '###Answer: ' if t.lang == 'en' else '###答:'
prompt, text = chat_content.split(split_mark)
prompt += split_mark
text += self.end_sym
prompts.append(prompt)
texts.append(text)
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)
to_regress_tokens = self.llama_tokenizer(
text,
texts,
return_tensors='pt',
padding='longest',
truncation=True,
@ -295,10 +319,12 @@ class MiniGPT4(BaseModel):
with torch.no_grad():
img_embeds, atts_img = self.encode_img(images)
if self.task == 'caption' and self.prompt_list:
prompt = random.choice(self.prompt_list)
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
prompt)
prompts = [
random.choice(self.zh_prompt_list) if hasattr(t, 'lang')
and t.lang == 'zh' else random.choice(self.en_prompt_list)
for t in data_samples
]
img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)
batch_size = img_embeds.shape[0]
bos = torch.ones(
@ -336,7 +362,6 @@ class MiniGPT4(BaseModel):
for output, data_sample in zip(outputs, data_samples):
if self.task == 'caption':
output = output.split('###')[0]
output = output.split('Assistant:')[-1].strip()
data_sample.pred_caption = output
else:
# raw output

View File

@ -0,0 +1,137 @@
# Modified from
# https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py
import dataclasses
from typing import List
import torch
@dataclasses.dataclass
class Conversation:
system: str
roles: List[str]
messages: List[List[str]]
sep: str = '###'
def get_prompt(self):
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += role + ': ' + message + self.sep
else:
ret += role + ':'
return ret
def append_message(self, role, message):
self.messages.append([role, message])
def copy(self):
return Conversation(
system=self.system,
roles=[role for role in self.roles],
messages=[[y for y in x] for x in self.messages],
sep=self.sep,
)
def dict(self):
return {
'system': self.system,
'roles': self.roles,
'messages': self.messages,
'offset': self.offset,
'sep': self.sep,
}
EN_CONV_VISION = Conversation(
system='Give the following image. '
'You will be able to see the image once I provide it to you. '
'Please answer my questions in detail.',
roles=['Ask', 'Answer'],
messages=[],
sep='###',
)
ZH_CONV_VISION = Conversation(
system='给定一张图片,请仔细观察这张图片,并回答我的问题。',
roles=['', ''],
messages=[],
sep='###',
)
class Chat:
def __init__(self, inferencer, device, is_half=False):
self.device = device
self.inferencer = inferencer
self.model = inferencer.model
self.is_half = is_half
if is_half:
self.model = self.model.half()
self.model = self.model.to(device)
self.max_length = 2000
def upload_img(self, image, conv, img_list):
img = next(self.inferencer.preprocess([image]))
img = self.model.data_preprocessor(img, False)['images']
img = img.to(self.device)
image_emb, _ = self.model.encode_img(img)
img_list.append(image_emb)
conv.append_message(conv.roles[0], '<Img><ImageHere></Img>')
def get_context_emb(self, conv, img_list):
prompt = conv.get_prompt()
prompt_segs = prompt.split('<ImageHere>')
seg_tokens = [
self.model.llama_tokenizer(
seg, return_tensors='pt',
add_special_tokens=(i == 0)).to(self.device).input_ids
for i, seg in enumerate(prompt_segs)
]
seg_embs = [
self.model.llama_model.model.embed_tokens(seg_token)
for seg_token in seg_tokens
]
mixed_embs = [
emb for pair in zip(seg_embs[:-1], img_list) for emb in pair
] + [seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
return mixed_embs
def ask(self, text, conv):
if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[
0] and conv.messages[-1][1][-6:] == '</Img>':
conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
else:
conv.append_message(conv.roles[0], text)
def answer(self, conv, img_list, generation_cfg):
conv.append_message(conv.roles[1], None)
embs = self.get_context_emb(conv, img_list)
cur_max_len = generation_cfg['max_new_tokens'] + embs.shape[1]
if cur_max_len > self.max_length:
print('Warning: The number of tokens in current conversation'
'exceeds the max length. '
'The model will not see the contexts outside the range.')
begin_idx = max(0, cur_max_len - self.max_length)
embs = embs[:, begin_idx:]
if self.is_half:
embs = embs.half()
outputs = self.model.llama_model.generate(
inputs_embeds=embs,
eos_token_id=self.model.end_token_id,
**generation_cfg)
output_token = outputs[0]
if output_token[0] == 0:
output_token = output_token[1:]
elif output_token[0] == 1:
output_token = output_token[1:]
output_text = self.model.llama_tokenizer.decode(
output_token,
add_special_tokens=False,
skip_special_tokens=True)
output_text = output_text.split('###')[0]
conv.messages[-1][1] = output_text
return output_text

View File

@ -0,0 +1,144 @@
import argparse
import gradio as gr
import numpy as np
import torch
from conversation import EN_CONV_VISION, ZH_CONV_VISION, Chat
from mmpretrain import ImageCaptionInferencer
parser = argparse.ArgumentParser(description='MiniGPT4 demo')
parser.add_argument(
'cfg', type=str, help='config file for minigpt4 (absolute path)')
parser.add_argument(
'ckpt', type=str, help='pretrained file for minigpt4 (absolute path)')
args = parser.parse_args()
if torch.cuda.is_available():
devices = [
torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
]
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
devices = [torch.device('mps')]
else:
devices = [torch.device('cpu')]
def get_free_device():
if hasattr(torch.cuda, 'mem_get_info'):
free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices]
select = max(zip(free, range(len(free))))[1]
else:
import random
select = random.randint(0, len(devices) - 1)
return devices[select]
device = get_free_device()
inferencer = ImageCaptionInferencer(model=args.cfg, pretrained=args.ckpt)
model = inferencer.model
chat = Chat(inferencer, device=device, is_half=(device.type != 'cpu'))
def reset(chat_state, img_list):
if chat_state is not None:
chat_state.messages = []
if img_list is not None:
img_list = []
return (None, gr.update(value=None, interactive=True),
gr.update(
value=None,
placeholder='Please upload your image first',
interactive=False),
gr.update(value='Upload & Start Chat',
interactive=True), chat_state, img_list,
gr.update(value='Restart', interactive=False),
gr.update(value='English', interactive=True))
def upload_img(gr_img, language, chat_state):
if gr_img is None:
return (None,
gr.update(
placeholder='Please upload your image first',
interactive=False),
gr.update(value='Upload & Start Chat',
interactive=True), chat_state, None,
gr.update(value='Restart', interactive=False),
gr.update(value='English', interactive=True))
if (language == 'English'):
chat_state = EN_CONV_VISION.copy()
else:
chat_state = ZH_CONV_VISION.copy()
img_list = []
gr_img_array = np.asarray(gr_img)
chat.upload_img(gr_img_array, chat_state, img_list)
return (gr.update(interactive=False),
gr.update(placeholder='Type and press Enter', interactive=True),
gr.update(value='Start Chatting',
interactive=False), chat_state, img_list,
gr.update(value='Restart',
interactive=True), gr.update(interactive=False))
def ask(user_message, chatbot, chat_state):
if (len(user_message) == 0):
return gr.update(
value=None,
placeholder='Input should not be empty!',
interactive=True), chatbot, chat_state
chat.ask(user_message, chat_state)
chatbot = chatbot + [[user_message, None]]
return '', chatbot, chat_state
def answer(chatbot, chat_state, img_list):
llm_message = chat.answer(
conv=chat_state,
img_list=img_list,
generation_cfg=model.generation_cfg)
chatbot[-1][1] = llm_message
return chatbot, chat_state, img_list
if __name__ == '__main__':
title = 'MMPretrain MiniGPT-4 Inference Demo'
with gr.Blocks(analytics_enabled=False, title=title) as demo:
gr.Markdown(f'# {title}')
with gr.Row():
with gr.Column():
image = gr.Image(type='pil')
language = gr.Dropdown(['English', 'Chinese'],
label='Language',
info='Select chatbot\'s language',
value='English',
interactive=True)
upload_button = gr.Button(
value='Upload & Start Chat', interactive=True)
clear = gr.Button(value='Restart', interactive=False)
with gr.Column():
chat_state = gr.State()
img_list = gr.State()
chatbot = gr.Chatbot(
label='MiniGPT-4', min_width=320, height=600)
text_input = gr.Textbox(
label='User',
placeholder='Please upload your image first',
interactive=False)
upload_button.click(upload_img, [image, language, chat_state], [
image, text_input, upload_button, chat_state, img_list, clear,
language
])
text_input.submit(ask, [text_input, chatbot, chat_state],
[text_input, chatbot, chat_state]).then(
answer, [chatbot, chat_state, img_list],
[chatbot, chat_state, img_list])
clear.click(reset, [chat_state, img_list], [
chatbot, image, text_input, upload_button, chat_state, img_list,
clear, language
])
demo.launch(share=True)