[DOC] Refine Inference Doc (#1489)

* update en doc

* update

* update zh doc

* refine

* refine
pull/1503/head
Ezra-Yu 2023-05-06 17:54:13 +08:00 committed by GitHub
parent afa60c73bb
commit 9cf37b315c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 206 additions and 47 deletions

View File

@ -1,32 +1,77 @@
# Inference with existing models
MMPretrain provides pre-trained models in [Model Zoo](../modelzoo_statistics.md).
This note will show **how to use existing models to inference on given images**.
This tutorial will show how to use the following APIs
As for how to test existing models on standard datasets, please see this [guide](./test.md)
1. [**`list_models`**](mmpretrain.apis.list_models) & [**`get_model`**](mmpretrain.apis.get_model) list models in MMPreTrain and get a specific model.
2. [**`ImageClassificationInferencer`**](mmpretrain.apis.ImageClassificationInferencer): inference on given images.
3. [**`FeatureExtractor`**](mmpretrain.apis.FeatureExtractor): extract features from the image files directly.
## List models and Get model
list all the models in MMPreTrain.
```
>>> from mmpretrain import list_models
>>> list_models()
['barlowtwins_resnet50_8xb256-coslr-300e_in1k',
'beit-base-p16_beit-in21k-pre_3rdparty_in1k',
.................]
```
`list_models` supports fuzzy matching, you can use **\*** to match any character.
```
>>> from mmpretrain import list_models
>>> list_models("*convnext-b*21k")
['convnext-base_3rdparty_in21k',
'convnext-base_in21k-pre-3rdparty_in1k-384px',
'convnext-base_in21k-pre_3rdparty_in1k']
```
you can use `get_model` get the model.
```
>>> from mmpretrain import get_model
# model without pre-trained weight
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k")
# model with default weight in MMPreTrain
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", pretrained=True)
# model with weight in local
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", pretrained="your_local_checkpoint_path")
# you can also do some modification, like modify the num_classes in head.
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", head=dict(num_classes=10))
# you can get model without neck, head, and output from stage 1, 2, 3 in backbone
>>> model_headless = get_model("resnet18_8xb32_in1k", head=None, neck=None, backbone=dict(out_indices=(1, 2, 3)))
```
Then you can do the forward:
```
>>> import torch
>>> from mmpretrain import get_model
>>> model = get_model('convnext-base_in21k-pre_3rdparty_in1k', pretrained=True)
>>> x = torch.rand((1, 3, 224, 224))
>>> y = model(x)
>>> print(type(y), y.shape)
<class 'torch.Tensor'> torch.Size([1, 1000])
```
## Inference on a given image
MMPretrain provides high-level Python APIs for inference on a given image:
- [`get_model`](mmpretrain.apis.get_model): Get a model with the model name.
- [`inference_model`](mmpretrain.apis.inference_model): Inference on a given image
Here is an example of building the model and inference on a given image by using ImageNet-1k pre-trained checkpoint.
```{note}
You can use `wget https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG` to download the example image or use your own image.
```
Here is an example of building the inferencer on a [given image](https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG) by using ImageNet-1k pre-trained checkpoint.
```python
from mmpretrain import get_model, inference_model
>>> from mmpretrain import ImageClassificationInferencer
img_path = 'demo.JPEG' # you can specify your own picture path
# build the model from a config file and a checkpoint file
model = get_model('resnet50_8xb32_in1k', pretrained=True, device="cpu") # device can be 'cuda:0'
# test a single image
result = inference_model(model, img_path)
>>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
>>> results = inferencer('https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG')
>>> print(results[0]['pred_class'])
sea snake
```
`result` is a dictionary containing `pred_label`, `pred_score`, `pred_scores` and `pred_class`, the result is as follows:
@ -35,4 +80,39 @@ result = inference_model(model, img_path)
{"pred_label":65,"pred_score":0.6649366617202759,"pred_class":"sea snake", "pred_scores": [..., 0.6649366617202759, ...]}
```
An image demo can be found in [demo/image_demo.py](https://github.com/open-mmlab/mmpretrain/blob/main/demo/image_demo.py).
If you want to use your own config and checkpoint:
```
>>> from mmpretrain import ImageClassificationInferencer
>>> inferencer = ImageClassificationInferencer(
model='configs/resnet/resnet50_8xb32_in1k.py',
pretrained='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
device='cuda')
>>> inferencer('https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG')
```
You can also inference multiple images by batch on CUDA:
```python
>>> from mmpretrain import ImageClassificationInferencer
>>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k', device='cuda')
>>> imgs = ['https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG'] * 5
>>> results = inferencer(imgs, batch_size=2)
>>> print(results[1]['pred_class'])
sea snake
```
## Extract Features From Image
Compared with `model.extract_feat`, `FeatureExtractor` is used to extract features from the image files directly, instead of a batch of tensors.
In a word, the input of `model.extract_feat` is `torch.Tensor`, the input of `FeatureExtractor` is images.
```
>>> from mmpretrain import FeatureExtractor, get_model
>>> model = get_model('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
>>> extractor = FeatureExtractor(model)
>>> features = extractor('https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG')[0]
>>> features[0].shape, features[1].shape, features[2].shape, features[3].shape
(torch.Size([256]), torch.Size([512]), torch.Size([1024]), torch.Size([2048]))
```

View File

@ -1,38 +1,117 @@
# 使用现有模型推理
# 使用现有模型进行推理
MMPretrain 在 [Model Zoo](../modelzoo_statistics.md) 中提供了预训练模型。
本说明将展示**如何使用现有模型对给定图像进行推理**。
本文将展示如何使用以下API
至于如何在标准数据集上测试现有模型,请看这个[指南](./test.md)
1. [**`list_models`**](mmpretrain.apis.list_models) 和 [**`get_model`**](mmpretrain.apis.get_model) :列出 MMPreTrain 中的模型并获取模型。
2. [**`ImageClassificationInferencer`**](mmpretrain.apis.ImageClassificationInferencer): 在给定图像上进行推理。
3. [**`FeatureExtractor`**](mmpretrain.apis.FeatureExtractor): 从图像文件直接提取特征。
## 推理单张图片
## 列出模型和获取模型
MMPretrain 为图像推理提供高级 Python API
列出 MMPreTrain 中的所有已支持的模型。
- [`get_model`](mmpretrain.apis.get_model): 根据名称获取一个模型。
- [`inference_model`](mmpretrain.apis.inference_model):对给定图片进行推理。
下面是一个示例,如何使用一个 ImageNet-1k 预训练权重初始化模型并推理给定图像。
```{note}
可以运行 `wget https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG` 下载样例图片,或使用其他图片。
```
>>> from mmpretrain import list_models
>>> list_models()
['barlowtwins_resnet50_8xb256-coslr-300e_in1k',
'beit-base-p16_beit-in21k-pre_3rdparty_in1k',
.................]
```
```python
from mmpretrain import get_model, inference_model
`list_models` 支持模糊匹配,您可以使用 **\*** 匹配任意字符。
img_path = 'demo.JPEG' # 可以指定自己的图片路径
# 构建模型
model = get_model('resnet50_8xb32_in1k', pretrained=True, device="cpu") # `device` 可以为 'cuda:0'
# 执行推理
result = inference_model(model, img_path)
```
>>> from mmpretrain import list_models
>>> list_models("*convnext-b*21k")
['convnext-base_3rdparty_in21k',
'convnext-base_in21k-pre-3rdparty_in1k-384px',
'convnext-base_in21k-pre_3rdparty_in1k']
```
`result` 为一个包含了 `pred_label`, `pred_score`, `pred_scores``pred_class`的字典,结果如下:
了解了已经支持了哪些模型后,你可以使用 `get_model` 获取特定模型。
```text
```
>>> from mmpretrain import get_model
# 没有预训练权重的模型
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k")
# 使用MMPreTrain中默认的权重
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", pretrained=True)
# 使用本地权重
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", pretrained="your_local_checkpoint_path")
# 您还可以做一些修改,例如修改 head 中的 num_classes。
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", head=dict(num_classes=10))
# 您可以获得没有 neckhead 的模型,并直接从 backbone 中的 stage 1, 2, 3 输出
>>> model_headless = get_model("resnet18_8xb32_in1k", head=None, neck=None, backbone=dict(out_indices=(1, 2, 3)))
```
得到模型后,你可以进行推理:
```
>>> import torch
>>> from mmpretrain import get_model
>>> model = get_model('convnext-base_in21k-pre_3rdparty_in1k', pretrained=True)
>>> x = torch.rand((1, 3, 224, 224))
>>> y = model(x)
>>> print(type(y), y.shape)
<class 'torch.Tensor'> torch.Size([1, 1000])
```
## 在给定图像上进行推理
这是一个使用 ImageNet-1k 预训练权重在给定图像上构建推理器的示例。
```
>>> from mmpretrain import ImageClassificationInferencer
>>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
>>> results = inferencer('https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG')
>>> print(results[0]['pred_class'])
sea snake
```
result 是一个包含 pred_label、pred_score、pred_scores 和 pred_class 的字典,结果如下:
```{text}
{"pred_label":65,"pred_score":0.6649366617202759,"pred_class":"sea snake", "pred_scores": [..., 0.6649366617202759, ...]}
```
演示可以在 [demo/image_demo.py](https://github.com/open-mmlab/mmpretrain/blob/main/demo/image_demo.py) 中找到。
如果你想使用自己的配置和权重:
```
>>> from mmpretrain import ImageClassificationInferencer
>>> inferencer = ImageClassificationInferencer(
model='configs/resnet/resnet50_8xb32_in1k.py',
pretrained='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
device='cuda')
>>> inferencer('https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG')
```
你还可以在CUDA上通过批处理进行多个图像的推理
```{python}
>>> from mmpretrain import ImageClassificationInferencer
>>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k', device='cuda')
>>> imgs = ['https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG'] * 5
>>> results = inferencer(imgs, batch_size=2)
>>> print(results[1]['pred_class'])
sea snake
```
## 从图像中提取特征
`model.extract_feat` 相比,`FeatureExtractor` 用于直接从图像文件中提取特征,而不是从一批张量中提取特征。简单说,`model.extract_feat` 的输入是 `torch.Tensor``FeatureExtractor` 的输入是图像。
```
>>> from mmpretrain import FeatureExtractor, get_model
>>> model = get_model('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
>>> extractor = FeatureExtractor(model)
>>> features = extractor('https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG')[0]
>>> features[0].shape, features[1].shape, features[2].shape, features[3].shape
(torch.Size([256]), torch.Size([512]), torch.Size([1024]), torch.Size([2048]))
```

View File

@ -28,7 +28,7 @@ class ImageClassificationInferencer(BaseInferencer):
file, or a :obj:`BaseModel` object. The model name can be found
by ``ImageClassificationInferencer.list_models()`` and you can also
query it in :doc:`/modelzoo_statistics`.
weights (str, optional): Path to the checkpoint. If None, it will try
pretrained (str, optional): Path to the checkpoint. If None, it will try
to find a pre-defined weight from the model you specified
(only work if the ``model`` is a model name). Defaults to None.
device (str, optional): Device to run inference. If None, use CPU or
@ -51,7 +51,7 @@ class ImageClassificationInferencer(BaseInferencer):
>>> from mmpretrain import ImageClassificationInferencer
>>> inferencer = ImageClassificationInferencer(
model='configs/resnet/resnet50_8xb32_in1k.py',
weights='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
pretrained='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
device='cuda')
>>> inferencer(['demo/dog.jpg', 'demo/bird.JPEG'], show_dir="./visualize/")
""" # noqa: E501