

# Inference with existing models

This document will show how to use the following APIs:

- [**`list_models`**](mmpretrain.apis.list_models): List the names of all available models in MMPretrain.
- [**`get_model`**](mmpretrain.apis.get_model): Get a model by model name or model config.
- [**`inference_model`**](mmpretrain.apis.inference_model): Run inference with an inferencer matching the model's task. It is mainly intended for quick demonstrations; for advanced usage, use the inferencers below directly.
- Inferencers:
  1. [**`ImageClassificationInferencer`**](mmpretrain.apis.ImageClassificationInferencer):
     Perform image classification on a given image.
  2. [**`ImageRetrievalInferencer`**](mmpretrain.apis.ImageRetrievalInferencer):
     Retrieve the image most similar to a given image from a given set of images.
  3. [**`ImageCaptionInferencer`**](mmpretrain.apis.ImageCaptionInferencer):
     Generate a caption for a given image, as shown in the example after this list.
  4. [**`VisualQuestionAnsweringInferencer`**](mmpretrain.apis.VisualQuestionAnsweringInferencer):
     Answer a question according to a given image.
  5. [**`VisualGroundingInferencer`**](mmpretrain.apis.VisualGroundingInferencer):
     Locate the object in a given image that matches a given description.
  6. [**`TextToImageRetrievalInferencer`**](mmpretrain.apis.TextToImageRetrievalInferencer):
     Retrieve the image most similar to a given text from a given set of images.
  7. [**`ImageToTextRetrievalInferencer`**](mmpretrain.apis.ImageToTextRetrievalInferencer):
     Retrieve the text most similar to a given image from a given set of texts.
  8. [**`NLVRInferencer`**](mmpretrain.apis.NLVRInferencer):
     Perform Natural Language for Visual Reasoning (the NLVR task) on a given pair of images and a piece of text.
  9. [**`FeatureExtractor`**](mmpretrain.apis.FeatureExtractor):
     Extract features from image files with a vision backbone.
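As a quick taste of the multi-modal inferencers, here is a minimal sketch of image captioning with `ImageCaptionInferencer`. The model name comes from the `list_models` output shown later in this document; `demo/cat-dog.png` is a hypothetical local image path, and the sketch assumes the caption result exposes a `pred_caption` field:
```python
>>> from mmpretrain import ImageCaptionInferencer
>>> inferencer = ImageCaptionInferencer('blip-base_3rdparty_caption')
>>> # Any local image path or image URL should work here.
>>> result = inferencer('demo/cat-dog.png')[0]
>>> print(result['pred_caption'])
```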
## List available models

List all supported models in MMPreTrain.
```python
>>> from mmpretrain import list_models
>>> list_models()
['barlowtwins_resnet50_8xb256-coslr-300e_in1k',
'beit-base-p16_beit-in21k-pre_3rdparty_in1k',
...]
```
`list_models` supports Unix filename pattern matching; you can use `*` to match any characters.
```python
>>> from mmpretrain import list_models
>>> list_models("*convnext-b*21k")
['convnext-base_3rdparty_in21k',
'convnext-base_in21k-pre-3rdparty_in1k-384px',
'convnext-base_in21k-pre_3rdparty_in1k']
```
You can also use the `list_models` method of an inferencer to get all models available for the corresponding task.
```python
>>> from mmpretrain import ImageCaptionInferencer
>>> ImageCaptionInferencer.list_models()
['blip-base_3rdparty_caption',
'blip2-opt2.7b_3rdparty-zeroshot_caption',
'flamingo_3rdparty-zeroshot_caption',
'ofa-base_3rdparty-finetuned_caption']
```
## Get a model

After choosing the model you need, you can use `get_model` to get that model.
```python
>>> from mmpretrain import get_model

# Get a model without loading pre-trained weights
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k")

# Get a model and load the default checkpoint
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", pretrained=True)

# Get a model and load a specified checkpoint
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", pretrained="your_local_checkpoint_path")

# Specify extra initialization arguments, for example, modify num_classes in the head
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", head=dict(num_classes=10))

# Another example: remove the neck and head, and output directly from stages 1, 2 and 3 of the backbone
>>> model_headless = get_model("resnet18_8xb32_in1k", head=None, neck=None, backbone=dict(out_indices=(1, 2, 3)))
```
The obtained model is a standard PyTorch module:
```python
>>> import torch
>>> from mmpretrain import get_model
>>> model = get_model('convnext-base_in21k-pre_3rdparty_in1k', pretrained=True)
>>> x = torch.rand((1, 3, 224, 224))
>>> y = model(x)
>>> print(type(y), y.shape)
<class 'torch.Tensor'> torch.Size([1, 1000])
```
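Similarly, the "headless" model built above should return backbone feature maps instead of classification logits. A minimal sketch, assuming that with the neck and head removed the default forward pass returns a tuple with one feature map per requested backbone stage:
```python
>>> import torch
>>> from mmpretrain import get_model
>>> model_headless = get_model("resnet18_8xb32_in1k", head=None, neck=None, backbone=dict(out_indices=(1, 2, 3)))
>>> x = torch.rand((1, 3, 224, 224))
>>> feats = model_headless(x)
>>> print(len(feats), [f.shape for f in feats])
```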
## Inference on a given image

Here is an example of classifying a given [image](https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG) with a ResNet-50 pre-trained model.
```python
>>> from mmpretrain import inference_model
>>> image = 'https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG'
>>> # If you have no graphical interface, set `show=False`
>>> result = inference_model('resnet50_8xb32_in1k', image, show=True)
>>> print(result['pred_class'])
sea snake
```
The `inference_model` API above is convenient for quick demonstrations, but it re-initializes the model on every call and cannot run inference on multiple samples. For repeated calls, use an inferencer instead.
```python
>>> from mmpretrain import ImageClassificationInferencer
>>> image = 'https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG'
>>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
>>> # Note that the inferencer always returns a list of results, even for a single input sample
>>> result = inferencer('https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG')[0]
>>> print(result['pred_class'])
sea snake
>>>
>>> # You can run batch inference on multiple images
>>> image_list = ['demo/demo.JPEG', 'demo/bird.JPEG'] * 16
>>> results = inferencer(image_list, batch_size=8)
>>> print(len(results))
32
>>> print(results[1]['pred_class'])
house finch, linnet, Carpodacus mexicanus
```
In general, the result for each sample is a dictionary. For example, an image classification result is a dictionary containing fields such as `pred_label`, `pred_score`, `pred_scores` and `pred_class`:
```python
{
    "pred_label": 65,
    "pred_score": 0.6649366617202759,
    "pred_class": "sea snake",
    "pred_scores": array([..., 0.6649366617202759, ...], dtype=float32)
}
```
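Because `pred_scores` is a plain NumPy array, you can post-process it directly. For example, a small sketch that prints the top-5 class indices and scores from the single-image `result` above, using only the dictionary fields documented here:
```python
>>> import numpy as np
>>> top5 = np.argsort(result['pred_scores'])[::-1][:5]
>>> for idx in top5:
...     print(idx, result['pred_scores'][idx])
```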
You can configure extra arguments for the inferencer, for example, to use your own config file and checkpoint and to run inference on a CUDA device:
```python
>>> from mmpretrain import ImageClassificationInferencer
>>> image = 'https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG'
>>> config = 'configs/resnet/resnet50_8xb32_in1k.py'
>>> checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth'
>>> inferencer = ImageClassificationInferencer(model=config, pretrained=checkpoint, device='cuda')
>>> result = inferencer(image)[0]
>>> print(result['pred_class'])
sea snake
```
## Inference with the Gradio demo

We also provide a Gradio-based demo that showcases inference for all the tasks supported by MMPretrain. You can find it at [projects/gradio_demo/launch.py](https://github.com/open-mmlab/mmpretrain/blob/main/projects/gradio_demo/launch.py).

Please install the `gradio` package first with `pip install -U gradio`.

Here is a preview of the interface:
<img src="https://user-images.githubusercontent.com/26739999/236147750-90ccb517-92c0-44e9-905e-1473677023b1.jpg" width="100%"/>
## Extract features from images

Compared with `model.extract_feat`, `FeatureExtractor` extracts features directly from image files rather than from a batch of tensors. In short, the input of `model.extract_feat` is a `torch.Tensor`, while the input of `FeatureExtractor` is images.
```python
>>> from mmpretrain import FeatureExtractor, get_model
>>> model = get_model('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
>>> extractor = FeatureExtractor(model)
>>> features = extractor('https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG')[0]
>>> features[0].shape, features[1].shape, features[2].shape, features[3].shape
(torch.Size([256]), torch.Size([512]), torch.Size([1024]), torch.Size([2048]))
```
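For comparison, here is a hedged sketch of the tensor-based counterpart: `model.extract_feat` takes a batched `torch.Tensor` and returns the per-stage features of the same model. The exact shapes keep the batch dimension and depend on the model's neck and `out_indices` settings:
```python
>>> import torch
>>> from mmpretrain import get_model
>>> model = get_model('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
>>> x = torch.rand((1, 3, 224, 224))
>>> feats = model.extract_feat(x)
>>> print([f.shape for f in feats])
```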