
# Inference with existing models

This tutorial shows how to use the following APIs:

1. [**`list_models`**](mmpretrain.apis.list_models) and [**`get_model`**](mmpretrain.apis.get_model): list the models available in MMPreTrain and get a model.
2. [**`ImageClassificationInferencer`**](mmpretrain.apis.ImageClassificationInferencer): run inference on given images.
3. [**`FeatureExtractor`**](mmpretrain.apis.FeatureExtractor): extract features directly from image files.

## List models and get a model

List all the models supported in MMPreTrain:

```
>>> from mmpretrain import list_models
>>> list_models()
['barlowtwins_resnet50_8xb256-coslr-300e_in1k',
'beit-base-p16_beit-in21k-pre_3rdparty_in1k',
.................]
```

`list_models` supports wildcard matching: use **\*** to match any sequence of characters.

```
>>> from mmpretrain import list_models
>>> list_models("*convnext-b*21k")
['convnext-base_3rdparty_in21k',
'convnext-base_in21k-pre-3rdparty_in1k-384px',
'convnext-base_in21k-pre_3rdparty_in1k']
```
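
Judging from the output above, the pattern also matches names that continue past its end (`convnext-base_in21k-pre_3rdparty_in1k` does not end with `21k`), as if a trailing `*` were appended. The sketch below reproduces that behavior with Python's standard `fnmatch` module over the model names shown above. The helper `match_models` is hypothetical, and whether `list_models` is implemented this way is an assumption.

```python
from fnmatch import fnmatchcase

# Model names copied from the outputs shown above; the real registry is much larger.
MODEL_NAMES = [
    'barlowtwins_resnet50_8xb256-coslr-300e_in1k',
    'beit-base-p16_beit-in21k-pre_3rdparty_in1k',
    'convnext-base_3rdparty_in21k',
    'convnext-base_in21k-pre-3rdparty_in1k-384px',
    'convnext-base_in21k-pre_3rdparty_in1k',
]


def match_models(pattern, names=MODEL_NAMES, allow_suffix=True):
    """Hypothetical helper: shell-style wildcard matching over model names.

    With allow_suffix=True a trailing '*' is appended to the pattern; this is
    an assumption made to reproduce the list_models output shown above.
    """
    if allow_suffix:
        pattern = pattern + '*'
    # fnmatchcase is case-sensitive on every platform.
    return sorted(name for name in names if fnmatchcase(name, pattern))


print(match_models('*convnext-b*21k'))
print(match_models('*convnext-b*21k', allow_suffix=False))
```

Without the implicit suffix, only `convnext-base_3rdparty_in21k` would match, because it is the only name that actually ends with `21k`.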

Once you know which models are supported, you can use `get_model` to get a specific model:

```
>>> from mmpretrain import get_model
# Model without pre-trained weights
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k")
# Use the default weights provided by MMPreTrain
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", pretrained=True)
# Use a local checkpoint
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", pretrained="your_local_checkpoint_path")
# You can also make some modifications, e.g. change num_classes in the head
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", head=dict(num_classes=10))
# Get a model without neck and head that outputs directly from stages 1, 2 and 3 of the backbone
>>> model_headless = get_model("resnet18_8xb32_in1k", head=None, neck=None, backbone=dict(out_indices=(1, 2, 3)))
```
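
The keyword arguments of `get_model` presumably act as partial overrides of the model config: `head=dict(num_classes=10)` changes only `num_classes` and keeps the other head settings, while `head=None` drops the component. The following is a minimal sketch of such a recursive merge in plain Python, assuming a simplified, hypothetical config dict for `resnet18_8xb32_in1k`. `merge_overrides` is illustrative only, not MMPreTrain's own merging code.

```python
import copy

# Simplified, assumed config for 'resnet18_8xb32_in1k'; the real config has more fields.
BASE_MODEL_CFG = {
    'type': 'ImageClassifier',
    'backbone': {'type': 'ResNet', 'depth': 18, 'num_stages': 4, 'out_indices': (3,)},
    'neck': {'type': 'GlobalAveragePooling'},
    'head': {'type': 'LinearClsHead', 'num_classes': 1000, 'in_channels': 512},
}


def merge_overrides(base, overrides):
    """Hypothetical helper: recursively apply keyword overrides to a config dict.

    Nested dicts are merged key by key, so {'head': {'num_classes': 10}} keeps
    the other head fields; any other value (including None) replaces the entry.
    The base dict is deep-copied and never modified.
    """
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_overrides(merged[key], value)
        else:
            merged[key] = value
    return merged


cfg = merge_overrides(BASE_MODEL_CFG, {'head': {'num_classes': 10}})
print(cfg['head'])

headless = merge_overrides(
    BASE_MODEL_CFG,
    {'head': None, 'neck': None, 'backbone': {'out_indices': (1, 2, 3)}},
)
print(headless['head'], headless['neck'], headless['backbone'])
```

Merging recursively, rather than replacing whole sub-dicts, is what lets a short override such as `backbone=dict(out_indices=(1, 2, 3))` keep the backbone's type and depth.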

Once you have the model, you can run inference:

```
>>> import torch
>>> from mmpretrain import get_model
>>> model = get_model('convnext-base_in21k-pre_3rdparty_in1k', pretrained=True)
>>> x = torch.rand((1, 3, 224, 224))
>>> y = model(x)
>>> print(type(y), y.shape)
<class 'torch.Tensor'> torch.Size([1, 1000])
```
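
Here `y` holds one row of 1000 unnormalized class scores (logits) per image, assuming the classification head returns raw scores in this mode. Because softmax preserves order, the most likely classes can be ranked directly from the logits. The sketch below does this in pure Python with toy values; `topk_indices` is a hypothetical helper.

```python
import heapq


def topk_indices(logits, k=5):
    """Hypothetical helper: return the indices of the k largest logits, highest first.

    Softmax is monotonic, so this ranking equals the ranking of the probabilities.
    """
    return heapq.nlargest(k, range(len(logits)), key=logits.__getitem__)


# Toy logits standing in for one row of the [1, 1000] output.
logits = [0.5, 2.0, -1.0, 3.0]
print(topk_indices(logits, k=2))
```

With the real tensor, `y.topk(5, dim=1)` returns the same ranking together with the corresponding scores.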

## Inference on given images

Here is an example of building an inferencer with ImageNet-1k pre-trained weights and running it on a given image:

```
>>> from mmpretrain import ImageClassificationInferencer
>>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
>>> results = inferencer('https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG')
>>> print(results[0]['pred_class'])
sea snake
```

Each item in `results` is a dict containing `pred_label`, `pred_score`, `pred_scores` and `pred_class`, for example:
```{text}
{"pred_label":65,"pred_score":0.6649366617202759,"pred_class":"sea snake", "pred_scores": [..., 0.6649366617202759, ...]}
```
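
The four fields are related: `pred_scores` holds one score per class, `pred_label` is the index of the highest score, `pred_score` is that score (note that it also appears inside `pred_scores` above), and `pred_class` is the class name at that index. The sketch below builds such a dict from toy logits, assuming the scores are softmax probabilities. `build_result` and the three class names are hypothetical.

```python
import math


def build_result(logits, classes):
    """Hypothetical helper mirroring the fields of one inferencer result.

    Assumes pred_scores are softmax probabilities over all classes.
    """
    max_logit = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - max_logit) for v in logits]
    total = sum(exps)
    scores = [e / total for e in exps]
    label = max(range(len(scores)), key=scores.__getitem__)
    return {
        'pred_label': label,
        'pred_score': scores[label],
        'pred_class': classes[label],
        'pred_scores': scores,
    }


# Toy 3-class example with hypothetical class names; the real model uses 1000 classes.
result = build_result([1.0, 3.0, 0.0], ['class_a', 'sea snake', 'class_c'])
print(result['pred_label'], result['pred_class'], round(result['pred_score'], 3))
```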

To use your own config and checkpoint:
```
>>> from mmpretrain import ImageClassificationInferencer
>>> inferencer = ImageClassificationInferencer(
...     model='configs/resnet/resnet50_8xb32_in1k.py',
...     pretrained='https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth',
...     device='cuda')
>>> inferencer('https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG')
```

You can also run batched inference on multiple images on CUDA:
```{python}
>>> from mmpretrain import ImageClassificationInferencer
>>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k', device='cuda')
>>> imgs = ['https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG'] * 5
>>> results = inferencer(imgs, batch_size=2)
>>> print(results[1]['pred_class'])
sea snake
```
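
With `batch_size=2`, the five images are presumably processed in chunks of 2, 2 and 1, and the per-image results are concatenated back in input order, which is why `results[1]` refers to the second image. The sketch below shows that chunking with a stand-in prediction function; `batched`, `run_in_batches` and `fake_predict` are hypothetical and no model is run.

```python
def batched(items, batch_size):
    """Hypothetical helper: yield consecutive chunks of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def run_in_batches(inputs, predict_batch, batch_size=1):
    """Hypothetical helper: call predict_batch per chunk and flatten the results in order."""
    results = []
    for chunk in batched(inputs, batch_size):
        results.extend(predict_batch(chunk))
    return results


batch_sizes_seen = []


def fake_predict(chunk):
    # Stand-in for a real forward pass: records the chunk size and returns one dict per input.
    batch_sizes_seen.append(len(chunk))
    return [{'pred_class': 'sea snake', 'input': img} for img in chunk]


imgs = ['demo.JPEG'] * 5
results = run_in_batches(imgs, fake_predict, batch_size=2)
print(batch_sizes_seen, len(results))
```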
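
## Extract features from images

Compared with `model.extract_feat`, `FeatureExtractor` extracts features directly from image files rather than from a batch of tensors. In short, `model.extract_feat` takes a `torch.Tensor` as input, while `FeatureExtractor` takes images.
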
```
>>> from mmpretrain import FeatureExtractor, get_model
>>> model = get_model('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
>>> extractor = FeatureExtractor(model)
>>> features = extractor('https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG')[0]
>>> features[0].shape, features[1].shape, features[2].shape, features[3].shape
(torch.Size([256]), torch.Size([512]), torch.Size([1024]), torch.Size([2048]))
```
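
Each feature above is one-dimensional: the spatial dimensions of the four ResNet-50 stage outputs have been averaged away, leaving one value per channel (256, 512, 1024 and 2048 channels). This is presumably done by the global average pooling neck in the `resnet50_8xb32_in1k` config. A minimal sketch of global average pooling on a nested-list `C x H x W` feature map follows; `global_average_pool` is a hypothetical helper.

```python
def global_average_pool(feature_map):
    """Hypothetical helper: average a C x H x W feature map over H and W, giving C values."""
    pooled = []
    for channel in feature_map:
        values = [v for row in channel for v in row]
        pooled.append(sum(values) / len(values))
    return pooled


# Toy feature map with 2 channels of size 2 x 3.
fmap = [
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    [[0.0, 0.0, 0.0], [0.0, 0.0, 6.0]],
]
print(global_average_pool(fmap))
```

The output length equals the number of channels, regardless of the spatial size, which is why each stage yields a vector whose length matches its channel count.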