From b4b84e637b3e69469953034942d7e13a1a7dc2a4 Mon Sep 17 00:00:00 2001
From: hmtbgc
Date: Wed, 13 Sep 2023 15:08:01 +0800
Subject: [PATCH] Provide command line arguments for the MiniGPT4 Gradio demo
 and update some docstrings

---
 mmpretrain/datasets/minigpt4_dataset.py        | 37 +++++++++++++++++++
 .../models/multimodal/minigpt4/minigpt4.py     |  5 +--
 projects/gradio_demo/minigpt4_demo.py          | 20 +++++++++-
 3 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/mmpretrain/datasets/minigpt4_dataset.py b/mmpretrain/datasets/minigpt4_dataset.py
index 0f022ba2..e14e5c35 100644
--- a/mmpretrain/datasets/minigpt4_dataset.py
+++ b/mmpretrain/datasets/minigpt4_dataset.py
@@ -10,6 +10,43 @@ from mmpretrain.registry import DATASETS
 
 @DATASETS.register_module()
 class MiniGPT4Dataset(BaseDataset):
+    """Dataset for training MiniGPT4.
+
+    The directory layout of the MiniGPT4 dataset:
+
+        minigpt4_dataset
+        ├── image
+        │   ├── id0.jpg
+        │   ├── id1.jpg
+        │   ├── id2.jpg
+        │   └── ...
+        └── conversation_data.json
+
+    The structure of conversation_data.json:
+
+        [
+            // English data
+            {
+                "id": str(id0),
+                "conversation": "###Ask: [Ask content]
+                ###Answer: [Answer content]"
+            },
+
+            // Chinese data
+            {
+                "id": str(id1),
+                "conversation": "###问: [Ask content]
+                ###答: [Answer content]"
+            },
+
+            ...
+        ]
+
+    Args:
+        data_root (str): The root directory for ``ann_file`` and ``image``.
+        ann_file (str): Conversation file path.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+    """
 
     def load_data_list(self) -> List[dict]:
         file_backend = get_file_backend(self.data_root)
diff --git a/mmpretrain/models/multimodal/minigpt4/minigpt4.py b/mmpretrain/models/multimodal/minigpt4/minigpt4.py
index dc82ccae..0a6db6de 100644
--- a/mmpretrain/models/multimodal/minigpt4/minigpt4.py
+++ b/mmpretrain/models/multimodal/minigpt4/minigpt4.py
@@ -201,13 +201,12 @@ class MiniGPT4(BaseModel):
                     prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
         """The function to wrap the image and prompt.
 
-        Currently, the function only supports applying one prompt to all input
-        images in the one batch.
+        Make sure that ``len(prompt) == img_embeds.shape[0]``.
 
         Args:
             img_embeds (torch.Tensor): The embedding of the input images.
             atts_img (torch.Tensor): Attention map of the image embeddings.
-            prompt (str): The prompt of the batch data.
+            prompt (List[str]): The prompt of the batch data.
 
         Returns:
             Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map.
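For context, the docstring above documents the on-disk layout, but the patch itself contains no config that consumes it. A rough sketch of a dataloader entry built on that layout follows; the data root, batch size, worker count and the empty pipeline are placeholders chosen for illustration, not values taken from this patch.

    train_dataloader = dict(
        batch_size=2,
        num_workers=4,
        dataset=dict(
            type='MiniGPT4Dataset',
            # Placeholder root: must contain the `image` directory and the
            # conversation file referenced by `ann_file`.
            data_root='data/minigpt4_dataset',
            ann_file='conversation_data.json',
            pipeline=[],  # image loading / formatting transforms go here
        ),
    )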
diff --git a/projects/gradio_demo/minigpt4_demo.py b/projects/gradio_demo/minigpt4_demo.py
index 0d255f79..402c0f23 100644
--- a/projects/gradio_demo/minigpt4_demo.py
+++ b/projects/gradio_demo/minigpt4_demo.py
@@ -1,3 +1,5 @@
+import argparse
+
 import gradio as gr
 import numpy as np
 import torch
@@ -5,6 +7,21 @@ from conversation import EN_CONV_VISION, ZH_CONV_VISION, Chat
 
 from mmpretrain import ImageCaptionInferencer
 
+parser = argparse.ArgumentParser(description='MiniGPT4 demo')
+parser.add_argument(
+    '-mp',
+    '--model_path',
+    type=str,
+    required=True,
+    help='Config file for MiniGPT4 (absolute path)')
+parser.add_argument(
+    '-pp',
+    '--pretrained_path',
+    type=str,
+    required=True,
+    help='Checkpoint file for MiniGPT4 (absolute path)')
+args = parser.parse_args()
+
 if torch.cuda.is_available():
     devices = [
         torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
@@ -26,7 +43,8 @@ def get_free_device():
 
 
 device = get_free_device()
-inferencer = ImageCaptionInferencer('minigpt-4_vicuna-7b_caption')
+inferencer = ImageCaptionInferencer(
+    model=args.model_path, pretrained=args.pretrained_path)
 model = inferencer.model
 chat = Chat(inferencer, device=device, is_half=True)
 
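With the hard-coded 'minigpt-4_vicuna-7b_caption' model name replaced by the two required arguments, the demo is launched with explicit paths, for example: python minigpt4_demo.py --model_path /abs/path/to/minigpt4_config.py --pretrained_path /abs/path/to/minigpt4_checkpoint.pth. A minimal sketch of the equivalent direct inferencer call is shown below; the three paths are placeholders, not files shipped with this patch.

    from mmpretrain import ImageCaptionInferencer

    # Placeholder paths standing in for --model_path / --pretrained_path.
    inferencer = ImageCaptionInferencer(
        model='/abs/path/to/minigpt4_config.py',
        pretrained='/abs/path/to/minigpt4_checkpoint.pth')

    # The demo wraps this inferencer in a Chat session; calling it directly
    # captions a single image instead.
    result = inferencer('/abs/path/to/image.jpg')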