# Inference with existing models

This tutorial will show how to use the following APIs:

- [**`list_models`**](mmpretrain.apis.list_models): List available model names in MMPreTrain.
- [**`get_model`**](mmpretrain.apis.get_model): Get a model from a model name or a model config.
- [**`inference_model`**](mmpretrain.apis.inference_model): Run inference on a model with the corresponding
  inferencer. It's a shortcut for a quick start; for advanced usage, please use the inferencers below
  directly.
- Inferencers:
  1. [**`ImageClassificationInferencer`**](mmpretrain.apis.ImageClassificationInferencer):
     Perform image classification on the given image.
  2. [**`ImageRetrievalInferencer`**](mmpretrain.apis.ImageRetrievalInferencer):
     Perform image-to-image retrieval from the given image on a given image set.
  3. [**`ImageCaptionInferencer`**](mmpretrain.apis.ImageCaptionInferencer):
     Generate a caption for the given image.
  4. [**`VisualQuestionAnsweringInferencer`**](mmpretrain.apis.VisualQuestionAnsweringInferencer):
     Answer a question according to the given image.
  5. [**`VisualGroundingInferencer`**](mmpretrain.apis.VisualGroundingInferencer):
     Locate an object in the given image from its description.
  6. [**`TextToImageRetrievalInferencer`**](mmpretrain.apis.TextToImageRetrievalInferencer):
     Perform text-to-image retrieval from the given description on a given image set.
  7. [**`ImageToTextRetrievalInferencer`**](mmpretrain.apis.ImageToTextRetrievalInferencer):
     Perform image-to-text retrieval from the given image on a series of texts.
  8. [**`NLVRInferencer`**](mmpretrain.apis.NLVRInferencer):
     Perform Natural Language for Visual Reasoning on a given image pair and text.
  9. [**`FeatureExtractor`**](mmpretrain.apis.FeatureExtractor):
     Extract features from image files with a vision backbone.

## List available models

List all the models in MMPreTrain.

```python
>>> from mmpretrain import list_models
>>> list_models()
['barlowtwins_resnet50_8xb256-coslr-300e_in1k',
 'beit-base-p16_beit-in21k-pre_3rdparty_in1k',
 ...]
```

`list_models` supports Unix filename pattern matching; you can use `*` to match any characters.

```python
>>> from mmpretrain import list_models
>>> list_models("*convnext-b*21k")
['convnext-base_3rdparty_in21k',
 'convnext-base_in21k-pre-3rdparty_in1k-384px',
 'convnext-base_in21k-pre_3rdparty_in1k']
```
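
To get a feel for the pattern syntax without installing anything, the matching can be approximated with Python's standard `fnmatch` module. This is only a sketch: `match_models` and `MODEL_NAMES` are hypothetical names, and the implied trailing `*` is an assumption inferred from the output above (where `*convnext-b*21k` also matches names that continue after `21k`), not taken from the MMPreTrain source.

```python
from fnmatch import fnmatchcase

# A few model names copied from the outputs shown in this tutorial.
MODEL_NAMES = [
    "barlowtwins_resnet50_8xb256-coslr-300e_in1k",
    "beit-base-p16_beit-in21k-pre_3rdparty_in1k",
    "convnext-base_3rdparty_in21k",
    "convnext-base_in21k-pre-3rdparty_in1k-384px",
    "convnext-base_in21k-pre_3rdparty_in1k",
]


def match_models(pattern, names=MODEL_NAMES):
    """Hypothetical helper: filter names with a Unix-style pattern.

    Assumption: a trailing "*" is implied, so the pattern only has to
    match the beginning of a name. This is inferred from the example
    output, not from the library code.
    """
    return [name for name in names if fnmatchcase(name, pattern + "*")]


# Selects the three ConvNeXt names from the list above.
print(match_models("*convnext-b*21k"))
# Selects only the BEiT name.
print(match_models("beit-*"))
```

If the real behaviour differs in your installed version, trust the result of `list_models` over this approximation.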

You can use the `list_models` method of inferencers to get the available models of the corresponding tasks.

```python
>>> from mmpretrain import ImageCaptionInferencer
>>> ImageCaptionInferencer.list_models()
['blip-base_3rdparty_caption',
'blip2-opt2.7b_3rdparty-zeroshot_caption',
'flamingo_3rdparty-zeroshot_caption',
'ofa-base_3rdparty-finetuned_caption']
```

## Get a model

You can use `get_model` to get the model.

```python
>>> from mmpretrain import get_model

# Get model without loading pre-trained weights.
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k")

# Get model and load the default checkpoint.
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", pretrained=True)

# Get model and load the specified checkpoint.
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", pretrained="your_local_checkpoint_path")

# Get model with extra initialization arguments, for example, modify the num_classes in head.
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", head=dict(num_classes=10))

# Another example, remove the neck and head, and output from stage 1, 2, 3 in backbone
>>> model_headless = get_model("resnet18_8xb32_in1k", head=None, neck=None, backbone=dict(out_indices=(1, 2, 3)))
```

The obtained model is a regular PyTorch module.

```python
>>> import torch
>>> from mmpretrain import get_model
>>> model = get_model('convnext-base_in21k-pre_3rdparty_in1k', pretrained=True)
>>> x = torch.rand((1, 3, 224, 224))
>>> y = model(x)
>>> print(type(y), y.shape)
<class 'torch.Tensor'> torch.Size([1, 1000])
```

## Inference on given images

Here is an example of running inference on an [image](https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG) with the ResNet-50 pre-trained classification model.

```python
>>> from mmpretrain import inference_model
>>> image = 'https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG'
>>> # If you have no graphical interface, please set `show=False`
>>> result = inference_model('resnet50_8xb32_in1k', image, show=True)
>>> print(result['pred_class'])
sea snake
```

The `inference_model` API is only for demos: it cannot keep the model instance or run inference on multiple
samples. For repeated calls, use the inferencers instead.

```python
>>> from mmpretrain import ImageClassificationInferencer
>>> image = 'https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG'
>>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
>>> # Note that the inferencer output is a list of results even if the input is a single sample.
>>> result = inferencer(image)[0]
>>> print(result['pred_class'])
sea snake
>>>
>>> # You can also use it for multiple images.
>>> image_list = ['demo/demo.JPEG', 'demo/bird.JPEG'] * 16
>>> results = inferencer(image_list, batch_size=8)
>>> print(len(results))
32
>>> print(results[1]['pred_class'])
house finch, linnet, Carpodacus mexicanus
```
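
As a rough mental model of the `batch_size` argument, the 32 inputs above can be thought of as being split into consecutive chunks of 8, with the results returned in input order. The sketch below only illustrates that chunking in plain Python; `iter_batches` is a hypothetical helper, and the inferencer's actual internal batching may differ.

```python
def iter_batches(items, batch_size):
    """Hypothetical helper: yield consecutive chunks of at most `batch_size` items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


image_list = ["demo/demo.JPEG", "demo/bird.JPEG"] * 16
batches = list(iter_batches(image_list, batch_size=8))

# 32 inputs split into chunks of 8 give 4 batches.
print(len(image_list), len(batches))

# Index 1 of the input list is the bird image, which is why
# results[1] in the example above is the house finch prediction.
print(image_list[1])
```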

Usually, the result for every sample is a dictionary. For example, the image classification result is a dictionary containing `pred_label`, `pred_score`, `pred_scores` and `pred_class` as follows:

```python
{
    "pred_label": 65,
    "pred_score": 0.6649366617202759,
    "pred_class": "sea snake",
    "pred_scores": array([..., 0.6649366617202759, ...], dtype=float32)
}
```
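
The relationship between these fields can be sketched with plain Python lists. The sketch assumes that `pred_label` is the index of the largest entry in `pred_scores`, that `pred_score` is that entry, and that `pred_class` is the class name at that index. This is consistent with the example above, where the same score appears in both fields, but it is stated here as an assumption. The scores and class names below are made up for illustration.

```python
# Toy scores for a hypothetical 4-class model (made-up values).
pred_scores = [0.05, 0.10, 0.70, 0.15]
# Hypothetical class names, one per score.
classes = ["cat", "dog", "sea snake", "bird"]

# Index of the highest score (the "argmax").
pred_label = max(range(len(pred_scores)), key=pred_scores.__getitem__)

result = {
    "pred_label": pred_label,
    "pred_score": pred_scores[pred_label],
    "pred_class": classes[pred_label],
    "pred_scores": pred_scores,
}

# Indices of the two highest scores, best first.
top2 = sorted(range(len(pred_scores)), key=pred_scores.__getitem__, reverse=True)[:2]

print(result["pred_label"], result["pred_class"], top2)
```

With a real result, `pred_scores` is a NumPy array, so the same top-k lookup would typically be done with array operations instead of `sorted`.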

You can configure the inferencer through its arguments, for example, to use your own config file and
checkpoint and run inference on CUDA.

```python
>>> from mmpretrain import ImageClassificationInferencer
>>> image = 'https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG'
>>> config = 'configs/resnet/resnet50_8xb32_in1k.py'
>>> checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth'
>>> inferencer = ImageClassificationInferencer(model=config, pretrained=checkpoint, device='cuda')
>>> result = inferencer(image)[0]
>>> print(result['pred_class'])
sea snake
```

## Inference by a Gradio demo

We also provide a Gradio demo for all supported tasks; you can find it in [projects/gradio_demo/launch.py](https://github.com/open-mmlab/mmpretrain/blob/main/projects/gradio_demo/launch.py).

Please install `gradio` first with `pip install -U gradio`.

Here is the interface preview:

<img src="https://user-images.githubusercontent.com/26739999/236147750-90ccb517-92c0-44e9-905e-1473677023b1.jpg" width="100%"/>

## Extract features from images

Compared with `model.extract_feat`, `FeatureExtractor` extracts features directly from image files instead of from a batch of tensors.
In short, the input of `model.extract_feat` is a `torch.Tensor`, while the input of `FeatureExtractor` is images.

```python
>>> from mmpretrain import FeatureExtractor, get_model
>>> model = get_model('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
>>> extractor = FeatureExtractor(model)
>>> features = extractor('https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG')[0]
>>> features[0].shape, features[1].shape, features[2].shape, features[3].shape
(torch.Size([256]), torch.Size([512]), torch.Size([1024]), torch.Size([2048]))
```