mirror of
https://github.com/open-mmlab/mmpretrain.git
synced 2025-06-03 14:59:18 +08:00
* [Feat] Migrate blip caption to mmpretrain. (#50) * Migrate blip caption to mmpretrain * minor fix * support train * [Feature] Support OFA caption task. (#51) * [Feature] Support OFA caption task. * Remove duplicated files. * [Feature] Support OFA vqa task. (#58) * [Feature] Support OFA vqa task. * Fix lint. * [Feat] Add BLIP retrieval to mmpretrain. (#55) * init * minor fix for train * fix according to comments * refactor * Update Blip retrieval. (#62) * [Feature] Support OFA visual grounding task. (#59) * [Feature] Support OFA visual grounding task. * minor add TODO --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * [Feat] Add flamingos coco caption and vqa. (#60) * first init * init flamingo coco * add vqa * minor fix * remove unnecessary modules * Update config * Use `ApplyToList`. --------- Co-authored-by: mzr1996 <mzr1996@163.com> * [Feature]: BLIP2 coco retrieval (#53) * [Feature]: Add blip2 retriever * [Feature]: Add blip2 all modules * [Feature]: Refine model * [Feature]: x1 * [Feature]: Runnable coco ret * [Feature]: Runnable version * [Feature]: Fix lint * [Fix]: Fix lint * [Feature]: Use 364 img size * [Feature]: Refactor blip2 * [Fix]: Fix lint * refactor files * minor fix * minor fix --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * Remove * fix blip caption inputs (#68) * [Feat] Add BLIP NLVR support. 
(#67) * first init * init flamingo coco * add vqa * add nlvr * refactor nlvr * minor fix * minor fix * Update dataset --------- Co-authored-by: mzr1996 <mzr1996@163.com> * [Feature]: BLIP2 Caption (#70) * [Feature]: Add language model * [Feature]: blip2 caption forward * [Feature]: Reproduce the results * [Feature]: Refactor caption * refine config --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * [Feat] Migrate BLIP VQA to mmpretrain (#69) * reformat * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * change * refactor code --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * Update RefCOCO dataset * [Fix] fix lint * [Feature] Implement inference APIs for multi-modal tasks. (#65) * [Feature] Implement inference APIs for multi-modal tasks. * [Project] Add gradio demo. * [Improve] Update requirements * Update flamingo * Update blip * Add NLVR inferencer * Update flamingo * Update hugging face model register * Update ofa vqa * Update BLIP-vqa (#71) * Update blip-vqa docstring (#72) * Refine flamingo docstring (#73) * [Feature]: BLIP2 VQA (#61) * [Feature]: VQA forward * [Feature]: Reproduce accuracy * [Fix]: Fix lint * [Fix]: Add blank line * minor fix --------- Co-authored-by: yingfhu <yingfhu@gmail.com> * [Feature]: BLIP2 docstring (#74) * [Feature]: Add caption docstring * [Feature]: Add docstring to blip2 vqa * [Feature]: Add docstring to retrieval * Update BLIP-2 metafile and README (#75) * [Feature]: Add readme and docstring * Update blip2 results --------- Co-authored-by: mzr1996 <mzr1996@163.com> * [Feature] BLIP Visual Grounding on MMPretrain Branch (#66) * blip grounding merge with mmpretrain * remove commit * blip grounding test and inference api * refcoco dataset * refcoco dataset refine config * rebasing * gitignore * rebasing * minor edit * minor edit * Update blip-vqa docstring (#72) * rebasing * Revert "minor edit" This reverts 
commit 639cec757c215e654625ed0979319e60f0be9044. * blip grounding final * precommit * refine config * refine config * Update blip visual grounding --------- Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com> Co-authored-by: mzr1996 <mzr1996@163.com> * Update visual grounding metric * Update OFA docstring, README and metafiles. (#76) * [Docs] Update installation docs and gradio demo docs. (#77) * Update OFA name * Update Visual Grounding Visualizer * Integrate accelerate support * Fix imports. * Fix timm backbone * Update imports * Update README * Update circle ci * Update flamingo config * Add gradio demo README * [Feature]: Add scienceqa (#1571) * [Feature]: Add scienceqa * [Feature]: Change param name * Update docs * Update video --------- Co-authored-by: Hubert <42952108+yingfhu@users.noreply.github.com> Co-authored-by: yingfhu <yingfhu@gmail.com> Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Co-authored-by: Yiqin Wang 王逸钦 <wyq1217@outlook.com> Co-authored-by: Rongjie Li <limo97@163.com>
467 lines
16 KiB
Python
467 lines
16 KiB
Python
from functools import partial
|
|
from pathlib import Path
|
|
from typing import Callable
|
|
|
|
import gradio as gr
|
|
import torch
|
|
from mmengine.logging import MMLogger
|
|
|
|
import mmpretrain
|
|
from mmpretrain.apis import (ImageCaptionInferencer,
|
|
ImageClassificationInferencer,
|
|
ImageRetrievalInferencer,
|
|
TextToImageRetrievalInferencer,
|
|
VisualGroundingInferencer,
|
|
VisualQuestionAnsweringInferencer)
|
|
from mmpretrain.utils.dependency import WITH_MULTIMODAL
|
|
from mmpretrain.visualization import UniversalVisualizer
|
|
|
|
# Silence mmpretrain's progress bars; they would only clutter the server log.
mmpretrain.utils.progress.disable_progress_bar = True

logger = MMLogger('mmpretrain', logger_name='mmpre')

# One torch.device per visible CUDA GPU, or None when CUDA is unavailable.
gpus = ([torch.device(f'cuda:{idx}')
         for idx in range(torch.cuda.device_count())]
        if torch.cuda.is_available() else None)

if gpus is None:
    logger.info('No available GPU.')
else:
    logger.info(f'Available GPUs: {len(gpus)}')
|
|
|
|
|
|
def get_free_device():
    """Choose the device on which to build a new inferencer.

    Returns the CPU when no GPU was discovered. Otherwise returns the GPU
    reporting the most free memory; ties go to the highest GPU index. On
    PyTorch builds without ``torch.cuda.mem_get_info``, a GPU is picked at
    random.
    """
    if gpus is None:
        return torch.device('cpu')

    if not hasattr(torch.cuda, 'mem_get_info'):
        # Free memory cannot be queried on this PyTorch version.
        import random
        return gpus[random.randint(0, len(gpus) - 1)]

    # ``mem_get_info`` returns ``(free, total)``; keep only the free bytes.
    free_bytes = [torch.cuda.mem_get_info(device)[0] for device in gpus]
    # Comparing ``(free, index)`` pairs breaks ties toward the larger index.
    best_idx = max(range(len(free_bytes)),
                   key=lambda idx: (free_bytes[idx], idx))
    return gpus[best_idx]
|
|
|
|
|
|
class InferencerCache:
    """Process-wide LRU cache of inferencer instances, keyed by name.

    Building an inferencer is expensive, so the demo keeps at most
    ``max_size`` of them and reuses them across requests. ``_cache`` is
    ordered most-recently-used first; entries are evicted from the tail.
    """

    # Maximum number of inferencers kept alive at once.
    max_size = 2
    # Shared (class-level) list of ``(instance_name, instance)`` pairs,
    # most recently used first.
    _cache = []

    @classmethod
    def get_instance(cls, instance_name, callback: Callable):
        """Return the inferencer cached under ``instance_name``, building it
        on a miss.

        Args:
            instance_name (str): Cache key identifying the inferencer.
            callback (Callable): Factory called as ``callback(device=...)``
                to build a new instance when the key is not cached.

        Returns:
            The cached or newly built instance.
        """
        for idx, (name, cached) in enumerate(cls._cache):
            if name == instance_name:
                # Move the hit to the head of the list (most recently used).
                cls._cache.insert(0, cls._cache.pop(idx))
                logger.info(f'Use cached {instance_name}.')
                return cached

        # Evict least-recently-used entries until there is room for one more.
        # Looping with ``>=`` (instead of a single pop when ``== max_size``)
        # keeps the cache bounded even if ``max_size`` is lowered at runtime;
        # the emptiness check prevents popping from an empty list.
        evicted = False
        while cls._cache and len(cls._cache) >= cls.max_size:
            cls._cache.pop()
            evicted = True
        if evicted:
            # Return the released allocator memory before building the next
            # model.
            torch.cuda.empty_cache()

        device = get_free_device()
        instance = callback(device=device)
        logger.info(f'New instance {instance_name} on {device}.')
        cls._cache.insert(0, (instance_name, instance))
        return instance
|
|
|
|
|
|
class ImageCaptionTab:
    """Gradio tab that generates a caption for an uploaded image."""

    def __init__(self) -> None:
        self.model_list = ImageCaptionInferencer.list_models()
        # ``create_ui`` only builds widgets and has no return value.
        self.tab = self.create_ui()

    def create_ui(self):
        """Build the model selector, image input and caption output widgets
        in the current Gradio Blocks context."""
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='image_caption_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value='blip-base_3rdparty_coco-caption',
                )
            with gr.Column():
                image_input = gr.Image(
                    label='Input',
                    source='upload',
                    elem_classes='input_image',
                    interactive=True,
                    tool='editor',
                )
                caption_output = gr.Textbox(
                    label='Result',
                    lines=2,
                    elem_classes='caption_result',
                    interactive=False,
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, image_input],
                    outputs=caption_output,
                )

    def inference(self, model, image):
        """Caption ``image`` with the selected model.

        Args:
            model (str): Model name chosen in the dropdown.
            image: Array from ``gr.Image`` (presumably HxWxC RGB, gradio's
                default numpy output — confirm), or None if nothing was
                uploaded.

        Returns:
            str: The predicted caption.

        Raises:
            gr.Error: If no model is selected or no image is uploaded.
        """
        # Fail with a readable UI message instead of a TypeError below.
        if not model:
            raise gr.Error('Please choose a model.')
        if image is None:
            raise gr.Error('Please upload an image.')

        # Reverse the channel axis (presumably RGB -> BGR for the inferencer
        # pipeline — confirm).
        image = image[:, :, ::-1]

        inferencer_name = self.__class__.__name__ + model
        inferencer = InferencerCache.get_instance(
            inferencer_name, partial(ImageCaptionInferencer, model))

        result = inferencer(image)[0]
        return result['pred_caption']
|
|
|
|
|
|
class ImageClassificationTab:
    """Gradio tab that classifies an uploaded image and shows the top-5
    class scores."""

    def __init__(self) -> None:
        # Small curated list shown by default; the full registry is offered
        # through the "Browse all models" checkbox.
        self.short_list = [
            'resnet50_8xb32_in1k',
            'resnet50_8xb256-rsb-a1-600e_in1k',
            'swin-base_16xb64_in1k',
            'convnext-base_32xb128_in1k',
            'vit-base-p16_32xb128-mae_in1k',
        ]
        self.long_list = ImageClassificationInferencer.list_models()
        # ``create_ui`` only builds widgets and has no return value.
        self.tab = self.create_ui()

    def create_ui(self):
        """Build the model selector, image input and label output widgets
        in the current Gradio Blocks context."""
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='image_classification_models',
                    elem_classes='select_model',
                    choices=self.short_list,
                    value='swin-base_16xb64_in1k',
                )
                expand = gr.Checkbox(label='Browse all models')

                def browse_all_model(value):
                    # Swap the dropdown choices between the curated and the
                    # full model list.
                    models = self.long_list if value else self.short_list
                    return gr.update(choices=models)

                expand.select(
                    fn=browse_all_model, inputs=expand, outputs=select_model)
            with gr.Column():
                in_image = gr.Image(
                    label='Input',
                    source='upload',
                    elem_classes='input_image',
                    interactive=True,
                    tool='editor',
                )
                out_cls = gr.Label(
                    label='Result',
                    num_top_classes=5,
                    elem_classes='cls_result',
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, in_image],
                    outputs=out_cls,
                )

    def inference(self, model, image):
        """Classify ``image`` with the selected model.

        Args:
            model (str): Model name chosen in the dropdown.
            image: Array from ``gr.Image`` (presumably HxWxC RGB — confirm),
                or None if nothing was uploaded.

        Returns:
            dict: Maps each class name (or class index, when the inferencer
            has no class names) to its predicted score.

        Raises:
            gr.Error: If no model is selected or no image is uploaded.
        """
        # Fail with a readable UI message instead of a TypeError below.
        if not model:
            raise gr.Error('Please choose a model.')
        if image is None:
            raise gr.Error('Please upload an image.')

        # Reverse the channel axis (presumably RGB -> BGR — confirm).
        image = image[:, :, ::-1]

        inferencer_name = self.__class__.__name__ + model
        inferencer = InferencerCache.get_instance(
            inferencer_name, partial(ImageClassificationInferencer, model))
        scores = inferencer(image)[0]['pred_scores'].tolist()

        classes = inferencer.classes
        if classes is None:
            # No class names available: label each score by its index.
            classes = list(range(len(scores)))

        return dict(zip(classes, scores))
|
|
|
|
|
|
class ImageRetrievalTab:
    """Gradio tab that retrieves, from a set of uploaded images, those most
    similar to a query image."""

    def __init__(self) -> None:
        self.model_list = ImageRetrievalInferencer.list_models()
        # ``create_ui`` only builds widgets and has no return value.
        self.tab = self.create_ui()

    def create_ui(self):
        """Build the model selector, top-k slider, gallery upload, query
        image and result gallery in the current Gradio Blocks context."""
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='image_retri_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value='resnet50-arcface_8xb32_inshop',
                )
                topk = gr.Slider(minimum=1, maximum=6, value=3, step=1)
            with gr.Column():
                prototype = gr.File(
                    label='Retrieve from',
                    file_count='multiple',
                    file_types=['image'])
                image_input = gr.Image(
                    label='Query',
                    source='upload',
                    elem_classes='input_image',
                    interactive=True,
                    tool='editor',
                )
                retri_output = gr.Gallery(
                    label='Result',
                    elem_classes='img_retri_result',
                ).style(
                    columns=[3], object_fit='contain', height='auto')
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, prototype, image_input, topk],
                    outputs=retri_output,
                )

    def inference(self, model, prototype, image, topk):
        """Retrieve the ``topk`` uploaded images closest to ``image``.

        Args:
            model (str): Model name chosen in the dropdown.
            prototype (list | None): Uploaded file objects from ``gr.File``;
                each exposes its temp path as ``.name``.
            image: Query array from ``gr.Image`` (presumably HxWxC RGB —
                confirm), or None if nothing was uploaded.
            topk (int | float): Number of matches requested by the slider.

        Returns:
            list[tuple[str, str]]: ``(image_path, match_score)`` pairs for
            the result gallery.

        Raises:
            gr.Error: If no model, gallery or query image is provided.
        """
        import hashlib

        # Fail with a readable UI message instead of a TypeError below.
        if not model:
            raise gr.Error('Please choose a model.')
        if not prototype:
            raise gr.Error('Please upload the images to retrieve from.')
        if image is None:
            raise gr.Error('Please upload a query image.')

        # Reverse the channel axis (presumably RGB -> BGR — confirm).
        image = image[:, :, ::-1]

        proto_paths = [file.name for file in prototype]
        # Key the cached inferencer (and its prototype cache file) on the
        # exact list of uploaded paths. Joining with a newline prevents two
        # different path lists from concatenating to the same string.
        proto_signature = hashlib.sha256(
            '\n'.join(proto_paths).encode()).hexdigest()
        inferencer_name = self.__class__.__name__ + model + proto_signature
        # Keep the prototype cache file next to the first uploaded image.
        cache_file = Path(proto_paths[0]).parent / f'{inferencer_name}.pth'

        inferencer = InferencerCache.get_instance(
            inferencer_name,
            partial(
                ImageRetrievalInferencer,
                model,
                prototype=proto_paths,
                prototype_cache=str(cache_file),
            ),
        )

        # Slider values may arrive as floats; never ask for more matches
        # than there are gallery images.
        result = inferencer(image, topk=min(int(topk), len(proto_paths)))[0]
        return [(str(item['sample']['img_path']),
                 str(item['match_score'].cpu().item())) for item in result]
|
|
|
|
|
|
class TextToImageRetrievalTab:
    """Gradio tab that retrieves, from a set of uploaded images, those best
    matching a text query."""

    def __init__(self) -> None:
        self.model_list = TextToImageRetrievalInferencer.list_models()
        # ``create_ui`` only builds widgets and has no return value.
        self.tab = self.create_ui()

    def create_ui(self):
        """Build the model selector, top-k slider, gallery upload, query
        textbox and result gallery in the current Gradio Blocks context."""
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='t2i_retri_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value='blip-base_3rdparty_coco-retrieval',
                )
                topk = gr.Slider(minimum=1, maximum=6, value=3, step=1)
            with gr.Column():
                # Labelled the same way as the image-to-image retrieval tab.
                prototype = gr.File(
                    label='Retrieve from',
                    file_count='multiple',
                    file_types=['image'])
                text_input = gr.Textbox(
                    label='Query',
                    elem_classes='input_text',
                    interactive=True,
                )
                retri_output = gr.Gallery(
                    label='Result',
                    elem_classes='img_retri_result',
                ).style(
                    columns=[3], object_fit='contain', height='auto')
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, prototype, text_input, topk],
                    outputs=retri_output,
                )

    def inference(self, model, prototype, text, topk):
        """Retrieve the ``topk`` uploaded images that best match ``text``.

        Args:
            model (str): Model name chosen in the dropdown.
            prototype (list | None): Uploaded file objects from ``gr.File``;
                each exposes its temp path as ``.name``.
            text (str): The text query.
            topk (int | float): Number of matches requested by the slider.

        Returns:
            list[tuple[str, str]]: ``(image_path, match_score)`` pairs for
            the result gallery.

        Raises:
            gr.Error: If no model, gallery or query text is provided.
        """
        import hashlib

        # Fail with a readable UI message instead of a TypeError below.
        if not model:
            raise gr.Error('Please choose a model.')
        if not prototype:
            raise gr.Error('Please upload the images to retrieve from.')
        if not text:
            raise gr.Error('Please enter a query text.')

        proto_paths = [file.name for file in prototype]
        # Key the cached inferencer (and its prototype cache file) on the
        # exact list of uploaded paths. Joining with a newline prevents two
        # different path lists from concatenating to the same string.
        proto_signature = hashlib.sha256(
            '\n'.join(proto_paths).encode()).hexdigest()
        inferencer_name = self.__class__.__name__ + model + proto_signature
        # Keep the prototype cache file next to the first uploaded image.
        cache_file = Path(proto_paths[0]).parent / f'{inferencer_name}.pth'

        inferencer = InferencerCache.get_instance(
            inferencer_name,
            partial(
                TextToImageRetrievalInferencer,
                model,
                prototype=proto_paths,
                prototype_cache=str(cache_file),
            ),
        )

        # Slider values may arrive as floats; never ask for more matches
        # than there are gallery images.
        result = inferencer(text, topk=min(int(topk), len(proto_paths)))[0]
        return [(str(item['sample']['img_path']),
                 str(item['match_score'].cpu().item())) for item in result]
|
|
|
|
|
|
class VisualGroundingTab:
    """Gradio tab that locates the object described by a text prompt in an
    uploaded image and draws the result."""

    def __init__(self) -> None:
        self.model_list = VisualGroundingInferencer.list_models()
        # ``create_ui`` only builds widgets and has no return value.
        self.tab = self.create_ui()
        self.visualizer = UniversalVisualizer(
            fig_save_cfg=dict(figsize=(16, 9)))

    def create_ui(self):
        """Build the model selector, image/text inputs and result image in
        the current Gradio Blocks context."""
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='vg_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value='ofa-base_3rdparty_refcoco',
                )
            with gr.Column():
                image_input = gr.Image(
                    label='Image',
                    source='upload',
                    elem_classes='input_image',
                    interactive=True,
                    tool='editor',
                )
                text_input = gr.Textbox(
                    label='The object to search',
                    elem_classes='input_text',
                    interactive=True,
                )
                vg_output = gr.Image(
                    label='Result',
                    source='upload',
                    interactive=False,
                    elem_classes='vg_result',
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, image_input, text_input],
                    outputs=vg_output,
                )

    def inference(self, model, image, text):
        """Ground ``text`` in ``image`` and return a visualization.

        Args:
            model (str): Model name chosen in the dropdown.
            image: Array from ``gr.Image`` (presumably HxWxC RGB — confirm),
                or None if nothing was uploaded.
            text (str): Description of the object to locate.

        Returns:
            The image produced by
            ``UniversalVisualizer.visualize_visual_grounding`` (resized with
            ``resize=512``).

        Raises:
            gr.Error: If no model is selected or no image is uploaded.
        """
        # Fail with a readable UI message instead of a TypeError below.
        if not model:
            raise gr.Error('Please choose a model.')
        if image is None:
            raise gr.Error('Please upload an image.')

        inferencer_name = self.__class__.__name__ + model
        inferencer = InferencerCache.get_instance(
            inferencer_name,
            partial(VisualGroundingInferencer, model),
        )

        # The inferencer gets the channel-reversed copy (presumably BGR —
        # confirm), while the visualizer draws on the original upload.
        result = inferencer(
            image[:, :, ::-1], text, return_datasamples=True)[0]
        vis = self.visualizer.visualize_visual_grounding(
            image, result, resize=512)
        return vis
|
|
|
|
|
|
class VisualQuestionAnsweringTab:
    """Gradio tab that answers a free-form question about an uploaded
    image."""

    def __init__(self) -> None:
        self.model_list = VisualQuestionAnsweringInferencer.list_models()
        # The fine-tuned OFA VQA model requires extra object descriptions
        # that this tab does not collect, so hide it. Check membership first
        # so a registry without that model does not crash the demo with a
        # ValueError at startup.
        unsupported = 'ofa-base_3rdparty-finetuned_vqa'
        if unsupported in self.model_list:
            self.model_list.remove(unsupported)
        # ``create_ui`` only builds widgets and has no return value.
        self.tab = self.create_ui()

    def create_ui(self):
        """Build the model selector, image/question inputs and answer box
        in the current Gradio Blocks context."""
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='vqa_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value='ofa-base_3rdparty-zeroshot_coco-vqa',
                )
            with gr.Column():
                image_input = gr.Image(
                    label='Input',
                    source='upload',
                    elem_classes='input_image',
                    interactive=True,
                    tool='editor',
                )
                question_input = gr.Textbox(
                    label='Question',
                    elem_classes='question_input',
                )
                answer_output = gr.Textbox(
                    label='Answer',
                    elem_classes='answer_result',
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, image_input, question_input],
                    outputs=answer_output,
                )

    def inference(self, model, image, question):
        """Answer ``question`` about ``image`` with the selected model.

        Args:
            model (str): Model name chosen in the dropdown.
            image: Array from ``gr.Image`` (presumably HxWxC RGB — confirm),
                or None if nothing was uploaded.
            question (str): The question to ask about the image.

        Returns:
            str: The predicted answer.

        Raises:
            gr.Error: If no model is selected or no image is uploaded.
        """
        # Fail with a readable UI message instead of a TypeError below.
        if not model:
            raise gr.Error('Please choose a model.')
        if image is None:
            raise gr.Error('Please upload an image.')

        # Reverse the channel axis (presumably RGB -> BGR — confirm).
        image = image[:, :, ::-1]

        inferencer_name = self.__class__.__name__ + model
        inferencer = InferencerCache.get_instance(
            inferencer_name, partial(VisualQuestionAnsweringInferencer, model))

        result = inferencer(image, question)[0]
        return result['pred_answer']
|
|
|
|
|
|
if __name__ == '__main__':
    title = 'MMPretrain Inference Demo'

    # (tab title, tab builder) pairs available with the base dependencies.
    base_tabs = [
        ('Image Classification', ImageClassificationTab),
        ('Image-To-Image Retrieval', ImageRetrievalTab),
    ]
    # Tabs that need the optional multi-modal dependencies.
    multimodal_tabs = [
        ('Image Caption', ImageCaptionTab),
        ('Text-To-Image Retrieval', TextToImageRetrievalTab),
        ('Visual Grounding', VisualGroundingTab),
        ('Visual Question Answering', VisualQuestionAnsweringTab),
    ]
    install_hint = ('To inference multi-modal models, please install '
                    'the extra multi-modal dependencies, please refer '
                    'to https://mmpretrain.readthedocs.io/en/latest/'
                    'get_started.html#installation')

    with gr.Blocks(analytics_enabled=False, title=title) as demo:
        gr.Markdown(f'# {title}')
        with gr.Tabs():
            enabled_tabs = base_tabs + (
                multimodal_tabs if WITH_MULTIMODAL else [])
            for tab_title, tab_cls in enabled_tabs:
                with gr.TabItem(tab_title):
                    tab_cls()
            if not WITH_MULTIMODAL:
                # Replace the multi-modal tabs with installation guidance.
                with gr.TabItem('Multi-modal tasks'):
                    gr.Markdown(install_hint)

    demo.launch()
|