"""A Gradio demo for MMPretrain inference APIs.

Run this script to launch a web UI with tabs for image classification,
image-to-image retrieval and, when the multi-modal dependencies are
installed, image caption, text-to-image retrieval, visual grounding and
visual question answering.
"""
from functools import partial
from pathlib import Path
from typing import Callable

import gradio as gr
import torch
from mmengine.logging import MMLogger

import mmpretrain
from mmpretrain.apis import (ImageCaptionInferencer,
                             ImageClassificationInferencer,
                             ImageRetrievalInferencer,
                             TextToImageRetrievalInferencer,
                             VisualGroundingInferencer,
                             VisualQuestionAnsweringInferencer)
from mmpretrain.utils.dependency import WITH_MULTIMODAL
from mmpretrain.visualization import UniversalVisualizer

mmpretrain.utils.progress.disable_progress_bar = True

logger = MMLogger('mmpretrain', logger_name='mmpre')

# Collect the available devices, preferring CUDA GPUs, then Apple MPS,
# and falling back to CPU.
if torch.cuda.is_available():
    devices = [
        torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
    ]
    logger.info(f'Available GPUs: {len(devices)}')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    devices = [torch.device('mps')]
    logger.info('MPS is available.')
else:
    devices = [torch.device('cpu')]
    logger.info('No accelerator available; using CPU.')


def get_free_device():
    """Select the CUDA device with the most free memory, or a random device
    when the free-memory query does not apply."""
    # Only query CUDA memory when actually running on CUDA devices;
    # `mem_get_info` exists on recent torch builds even without a GPU.
    if devices[0].type == 'cuda' and hasattr(torch.cuda, 'mem_get_info'):
        free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices]
        select = max(zip(free, range(len(free))))[1]
    else:
        import random
        select = random.randint(0, len(devices) - 1)
    return devices[select]


class InferencerCache:
    """A small LRU-style cache keeping at most ``max_size`` inferencer
    instances alive, so model weights are not reloaded on every request."""

    max_size = 2
    _cache = []

    @classmethod
    def get_instance(cls, instance_name, callback: Callable):
        if len(cls._cache) > 0:
            for i, cache in enumerate(cls._cache):
                if cache[0] == instance_name:
                    # Re-insert to the head of the list.
                    cls._cache.insert(0, cls._cache.pop(i))
                    logger.info(f'Use cached {instance_name}.')
                    return cache[1]

        if len(cls._cache) == cls.max_size:
            # Evict the least recently used instance before creating a new one.
            cls._cache.pop(cls.max_size - 1)
            torch.cuda.empty_cache()
        device = get_free_device()
        instance = callback(device=device)
        logger.info(f'New instance {instance_name} on {device}.')
        cls._cache.insert(0, (instance_name, instance))
        return instance
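# Illustrative sketch (not executed here) of how the tab classes below use
# the cache: the key concatenates the tab class name and the model name, and
# the callback only runs on a cache miss, receiving the selected device.
#
#   inferencer = InferencerCache.get_instance(
#       'ImageClassificationTab' + 'resnet50_8xb32_in1k',
#       partial(ImageClassificationInferencer, 'resnet50_8xb32_in1k'))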
class ImageCaptionTab:

    def __init__(self) -> None:
        self.model_list = ImageCaptionInferencer.list_models()
        self.tab = self.create_ui()

    def create_ui(self):
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='image_caption_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value='blip-base_3rdparty_coco-caption',
                )
            with gr.Column():
                image_input = gr.Image(
                    label='Input',
                    source='upload',
                    elem_classes='input_image',
                    interactive=True,
                    tool='editor',
                )
                caption_output = gr.Textbox(
                    label='Result',
                    lines=2,
                    elem_classes='caption_result',
                    interactive=False,
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, image_input],
                    outputs=caption_output,
                )

    def inference(self, model, image):
        # Gradio provides RGB images; the inferencers expect BGR.
        image = image[:, :, ::-1]
        inferencer_name = self.__class__.__name__ + model
        inferencer = InferencerCache.get_instance(
            inferencer_name, partial(ImageCaptionInferencer, model))

        result = inferencer(image)[0]
        return result['pred_caption']


class ImageClassificationTab:

    def __init__(self) -> None:
        # A curated short list shown by default; the full model list is
        # available behind the "Browse all models" checkbox.
        self.short_list = [
            'resnet50_8xb32_in1k',
            'resnet50_8xb256-rsb-a1-600e_in1k',
            'swin-base_16xb64_in1k',
            'convnext-base_32xb128_in1k',
            'vit-base-p16_32xb128-mae_in1k',
        ]
        self.long_list = ImageClassificationInferencer.list_models()
        self.tab = self.create_ui()

    def create_ui(self):
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='image_classification_models',
                    elem_classes='select_model',
                    choices=self.short_list,
                    value='swin-base_16xb64_in1k',
                )
                expand = gr.Checkbox(label='Browse all models')

                def browse_all_model(value):
                    models = self.long_list if value else self.short_list
                    return gr.update(choices=models)

                expand.select(
                    fn=browse_all_model, inputs=expand, outputs=select_model)
            with gr.Column():
                in_image = gr.Image(
                    label='Input',
                    source='upload',
                    elem_classes='input_image',
                    interactive=True,
                    tool='editor',
                )
                out_cls = gr.Label(
                    label='Result',
                    num_top_classes=5,
                    elem_classes='cls_result',
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, in_image],
                    outputs=out_cls,
                )

    def inference(self, model, image):
        image = image[:, :, ::-1]
        inferencer_name = self.__class__.__name__ + model
        inferencer = InferencerCache.get_instance(
            inferencer_name, partial(ImageClassificationInferencer, model))

        result = inferencer(image)[0]['pred_scores'].tolist()
        if inferencer.classes is not None:
            classes = inferencer.classes
        else:
            # Fall back to index labels if the model has no class names.
            classes = list(range(len(result)))
        return dict(zip(classes, result))


class ImageRetrievalTab:

    def __init__(self) -> None:
        self.model_list = ImageRetrievalInferencer.list_models()
        self.tab = self.create_ui()

    def create_ui(self):
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='image_retri_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value='resnet50-arcface_inshop',
                )
                topk = gr.Slider(minimum=1, maximum=6, value=3, step=1)
            with gr.Column():
                prototype = gr.File(
                    label='Retrieve from',
                    file_count='multiple',
                    file_types=['image'])
                image_input = gr.Image(
                    label='Query',
                    source='upload',
                    elem_classes='input_image',
                    interactive=True,
                    tool='editor',
                )
                retri_output = gr.Gallery(
                    label='Result',
                    elem_classes='img_retri_result',
                ).style(
                    columns=[3], object_fit='contain', height='auto')
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, prototype, image_input, topk],
                    outputs=retri_output,
                )

    def inference(self, model, prototype, image, topk):
        image = image[:, :, ::-1]

        # Key the cached inferencer on both the model and the uploaded
        # prototype set, so changing either rebuilds the feature index.
        import hashlib
        proto_signature = ''.join(file.name for file in prototype).encode()
        proto_signature = hashlib.sha256(proto_signature).hexdigest()
        inferencer_name = self.__class__.__name__ + model + proto_signature
        tmp_dir = Path(prototype[0].name).parent
        cache_file = tmp_dir / f'{inferencer_name}.pth'

        inferencer = InferencerCache.get_instance(
            inferencer_name,
            partial(
                ImageRetrievalInferencer,
                model,
                prototype=[file.name for file in prototype],
                prototype_cache=str(cache_file),
            ),
        )
        result = inferencer(image, topk=min(topk, len(prototype)))[0]
        return [(str(item['sample']['img_path']),
                 str(item['match_score'].cpu().item())) for item in result]
class TextToImageRetrievalTab:

    def __init__(self) -> None:
        self.model_list = TextToImageRetrievalInferencer.list_models()
        self.tab = self.create_ui()

    def create_ui(self):
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='t2i_retri_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value='blip-base_3rdparty_coco-retrieval',
                )
                topk = gr.Slider(minimum=1, maximum=6, value=3, step=1)
            with gr.Column():
                prototype = gr.File(
                    file_count='multiple', file_types=['image'])
                text_input = gr.Textbox(
                    label='Query',
                    elem_classes='input_text',
                    interactive=True,
                )
                retri_output = gr.Gallery(
                    label='Result',
                    elem_classes='img_retri_result',
                ).style(
                    columns=[3], object_fit='contain', height='auto')
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, prototype, text_input, topk],
                    outputs=retri_output,
                )

    def inference(self, model, prototype, text, topk):
        import hashlib
        proto_signature = ''.join(file.name for file in prototype).encode()
        proto_signature = hashlib.sha256(proto_signature).hexdigest()
        inferencer_name = self.__class__.__name__ + model + proto_signature
        tmp_dir = Path(prototype[0].name).parent
        cache_file = tmp_dir / f'{inferencer_name}.pth'

        inferencer = InferencerCache.get_instance(
            inferencer_name,
            partial(
                TextToImageRetrievalInferencer,
                model,
                prototype=[file.name for file in prototype],
                prototype_cache=str(cache_file),
            ),
        )
        result = inferencer(text, topk=min(topk, len(prototype)))[0]
        return [(str(item['sample']['img_path']),
                 str(item['match_score'].cpu().item())) for item in result]


class VisualGroundingTab:

    def __init__(self) -> None:
        self.model_list = VisualGroundingInferencer.list_models()
        self.tab = self.create_ui()
        self.visualizer = UniversalVisualizer(
            fig_save_cfg=dict(figsize=(16, 9)))

    def create_ui(self):
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='vg_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value='ofa-base_3rdparty_refcoco',
                )
            with gr.Column():
                image_input = gr.Image(
                    label='Image',
                    source='upload',
                    elem_classes='input_image',
                    interactive=True,
                    tool='editor',
                )
                text_input = gr.Textbox(
                    label='The object to search',
                    elem_classes='input_text',
                    interactive=True,
                )
                vg_output = gr.Image(
                    label='Result',
                    source='upload',
                    interactive=False,
                    elem_classes='vg_result',
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, image_input, text_input],
                    outputs=vg_output,
                )

    def inference(self, model, image, text):
        inferencer_name = self.__class__.__name__ + model
        inferencer = InferencerCache.get_instance(
            inferencer_name,
            partial(VisualGroundingInferencer, model),
        )
        result = inferencer(
            image[:, :, ::-1], text, return_datasamples=True)[0]
        vis = self.visualizer.visualize_visual_grounding(
            image, result, resize=512)
        return vis
class VisualQuestionAnsweringTab:

    def __init__(self) -> None:
        self.model_list = VisualQuestionAnsweringInferencer.list_models()
        # The fine-tuned OFA VQA models require an extra object description,
        # so exclude them from the demo.
        self.model_list.remove('ofa-base_3rdparty-finetuned_vqa')
        self.tab = self.create_ui()

    def create_ui(self):
        with gr.Row():
            with gr.Column():
                select_model = gr.Dropdown(
                    label='Choose a model',
                    elem_id='vqa_models',
                    elem_classes='select_model',
                    choices=self.model_list,
                    value='ofa-base_3rdparty-zeroshot_coco-vqa',
                )
            with gr.Column():
                image_input = gr.Image(
                    label='Input',
                    source='upload',
                    elem_classes='input_image',
                    interactive=True,
                    tool='editor',
                )
                question_input = gr.Textbox(
                    label='Question',
                    elem_classes='question_input',
                )
                answer_output = gr.Textbox(
                    label='Answer',
                    elem_classes='answer_result',
                )
                run_button = gr.Button(
                    'Run',
                    elem_classes='run_button',
                )
                run_button.click(
                    self.inference,
                    inputs=[select_model, image_input, question_input],
                    outputs=answer_output,
                )

    def inference(self, model, image, question):
        image = image[:, :, ::-1]
        inferencer_name = self.__class__.__name__ + model
        inferencer = InferencerCache.get_instance(
            inferencer_name,
            partial(VisualQuestionAnsweringInferencer, model))

        result = inferencer(image, question)[0]
        return result['pred_answer']


if __name__ == '__main__':
    title = 'MMPretrain Inference Demo'
    with gr.Blocks(analytics_enabled=False, title=title) as demo:
        gr.Markdown(f'# {title}')
        with gr.Tabs():
            with gr.TabItem('Image Classification'):
                ImageClassificationTab()
            with gr.TabItem('Image-To-Image Retrieval'):
                ImageRetrievalTab()
            if WITH_MULTIMODAL:
                with gr.TabItem('Image Caption'):
                    ImageCaptionTab()
                with gr.TabItem('Text-To-Image Retrieval'):
                    TextToImageRetrievalTab()
                with gr.TabItem('Visual Grounding'):
                    VisualGroundingTab()
                with gr.TabItem('Visual Question Answering'):
                    VisualQuestionAnsweringTab()
            else:
                with gr.TabItem('Multi-modal tasks'):
                    gr.Markdown(
                        'To run inference with multi-modal models, please '
                        'install the extra multi-modal dependencies; refer '
                        'to https://mmpretrain.readthedocs.io/en/latest/'
                        'get_started.html#installation')

    demo.launch()
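    # `demo.launch()` serves the UI locally (Gradio's default address is
    # http://127.0.0.1:7860). To expose a temporary public link instead, one
    # could pass `share=True`, a standard Gradio launch option.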