provide command line arguments for minigpt4 gradio demo and update some comments

pull/1758/head
hmtbgc 2023-09-13 15:08:01 +08:00
parent 120ef98853
commit b4b84e637b
3 changed files with 58 additions and 4 deletions

View File

@ -10,6 +10,43 @@ from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class MiniGPT4Dataset(BaseDataset):
"""Dataset for training MiniGPT4.
MiniGPT4 dataset directory:
minigpt4_dataset
image
id0.jpg
id1.jpg
id2.jpg
...
conversation_data.json
The structure of conversation_data.json:
[
// English data
{
"id": str(id0),
"conversation": "###Ask: <Img><ImageHere></Img> [Ask content]
###Answer: [Answer content]"
},
// Chinese data
{
"id": str(id1),
"conversation": "###问:<Img><ImageHere></Img> [Ask content]
###答:[Answer content]"
},
...
]
Args:
data_root (str): The root directory for ``ann_file`` and ``image``.
ann_file (str): Conversation file path.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""
def load_data_list(self) -> List[dict]:
file_backend = get_file_backend(self.data_root)

View File

@ -201,13 +201,12 @@ class MiniGPT4(BaseModel):
prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
"""The function to wrap the image and prompt.
Currently, the function only supports applying one prompt to all input
images in the one batch.
Make sure that len(prompt) == img_embeds.shape[0].
Args:
img_embeds (torch.Tensor): The embedding of the input images.
atts_img (torch.Tensor): Attention map of the image embeddings.
prompt (str): The prompt of the batch data.
prompt (List[str]): The prompt of the batch data.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map.

View File

@ -1,3 +1,5 @@
import argparse
import gradio as gr
import numpy as np
import torch
@ -5,6 +7,21 @@ from conversation import EN_CONV_VISION, ZH_CONV_VISION, Chat
from mmpretrain import ImageCaptionInferencer
parser = argparse.ArgumentParser(description='MiniGPT4 demo')
parser.add_argument(
'-mp',
'--model_path',
type=str,
required=True,
help='config file for minigpt4 (absolute path)')
parser.add_argument(
'-pp',
'--pretrained_path',
type=str,
required=True,
help='pretrained file for minigpt4 (absolute path)')
args = parser.parse_args()
if torch.cuda.is_available():
devices = [
torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
@ -26,7 +43,8 @@ def get_free_device():
device = get_free_device()
inferencer = ImageCaptionInferencer('minigpt-4_vicuna-7b_caption')
inferencer = ImageCaptionInferencer(
model=args.model_path, pretrained=args.pretrained_path)
model = inferencer.model
chat = Chat(inferencer, device=device, is_half=True)