provide command line arguments for minigpt4 gradio demo and update some comments
parent
120ef98853
commit
b4b84e637b
|
@ -10,6 +10,43 @@ from mmpretrain.registry import DATASETS
|
|||
|
||||
@DATASETS.register_module()
|
||||
class MiniGPT4Dataset(BaseDataset):
|
||||
"""Dataset for training MiniGPT4.
|
||||
|
||||
MiniGPT4 dataset directory:
|
||||
|
||||
minigpt4_dataset
|
||||
├── image
|
||||
│ ├── id0.jpg
|
||||
│ │── id1.jpg
|
||||
│ │── id2.jpg
|
||||
│ └── ...
|
||||
└── conversation_data.json
|
||||
|
||||
The structure of conversation_data.json:
|
||||
|
||||
[
|
||||
// English data
|
||||
{
|
||||
"id": str(id0),
|
||||
"conversation": "###Ask: <Img><ImageHere></Img> [Ask content]
|
||||
###Answer: [Answer content]"
|
||||
},
|
||||
|
||||
// Chinese data
|
||||
{
|
||||
"id": str(id1),
|
||||
"conversation": "###问:<Img><ImageHere></Img> [Ask content]
|
||||
###答:[Answer content]"
|
||||
},
|
||||
|
||||
...
|
||||
]
|
||||
|
||||
Args:
|
||||
data_root (str): The root directory for ``ann_file`` and ``image``.
|
||||
ann_file (str): Conversation file path.
|
||||
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||
"""
|
||||
|
||||
def load_data_list(self) -> List[dict]:
|
||||
file_backend = get_file_backend(self.data_root)
|
||||
|
|
|
@ -201,13 +201,12 @@ class MiniGPT4(BaseModel):
|
|||
prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""The function to wrap the image and prompt.
|
||||
|
||||
Currently, the function only supports applying one prompt to all input
|
||||
images in the one batch.
|
||||
Make sure that len(prompt) == img_embeds.shape[0].
|
||||
|
||||
Args:
|
||||
img_embeds (torch.Tensor): The embedding of the input images.
|
||||
atts_img (torch.Tensor): Attention map of the image embeddings.
|
||||
prompt (str): The prompt of the batch data.
|
||||
prompt (List[str]): The prompt of the batch data.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map.
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import argparse
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -5,6 +7,21 @@ from conversation import EN_CONV_VISION, ZH_CONV_VISION, Chat
|
|||
|
||||
from mmpretrain import ImageCaptionInferencer
|
||||
|
||||
parser = argparse.ArgumentParser(description='MiniGPT4 demo')
|
||||
parser.add_argument(
|
||||
'-mp',
|
||||
'--model_path',
|
||||
type=str,
|
||||
required=True,
|
||||
help='config file for minigpt4 (absolute path)')
|
||||
parser.add_argument(
|
||||
'-pp',
|
||||
'--pretrained_path',
|
||||
type=str,
|
||||
required=True,
|
||||
help='pretrained file for minigpt4 (absolute path)')
|
||||
args = parser.parse_args()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
devices = [
|
||||
torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
|
||||
|
@ -26,7 +43,8 @@ def get_free_device():
|
|||
|
||||
|
||||
device = get_free_device()
|
||||
inferencer = ImageCaptionInferencer('minigpt-4_vicuna-7b_caption')
|
||||
inferencer = ImageCaptionInferencer(
|
||||
model=args.model_path, pretrained=args.pretrained_path)
|
||||
model = inferencer.model
|
||||
chat = Chat(inferencer, device=device, is_half=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue