[Feature]: add onnxruntime test tool (#212)

* [draft] add onnxruntime accuracy verification

* fix a bug

* update code

* fix lint

* fix lint

* update code and doc

* update doc

* update code

* update code

* update doc and update code

* update doc and fix some bugs

* update doc

* update doc

* update doc

* update doc

* update doc

* update doc

* fix bug

* update doc

* update code

* move CUDAExecutionProvider to first place

* update resnext accuracy

* update doc

Co-authored-by: maningsheng <maningsheng@sensetime.com>


@@ -43,7 +43,7 @@ Description of all arguments:
Example:
```bash
-python tools/onnx2tensorrt.py \
+python tools/deployment/onnx2tensorrt.py \
    checkpoints/resnet/resnet18_b16x8_cifar10.onnx \
    --trt-file checkpoints/resnet/resnet18_b16x8_cifar10.trt \
    --shape 224 224 \


@@ -6,6 +6,12 @@
- [How to convert models from Pytorch to ONNX](#how-to-convert-models-from-pytorch-to-onnx)
- [Prerequisite](#prerequisite)
- [Usage](#usage)
+- [Description of all arguments](#description-of-all-arguments)
+- [How to evaluate ONNX models with ONNX Runtime](#how-to-evaluate-onnx-models-with-onnx-runtime)
+- [Prerequisite](#prerequisite-1)
+- [Usage](#usage-1)
+- [Description of all arguments](#description-of-all-arguments-1)
+- [Results and Models](#results-and-models)
- [List of supported models exportable to ONNX](#list-of-supported-models-exportable-to-onnx)
- [Reminders](#reminders)
- [FAQs](#faqs)
@@ -26,26 +32,26 @@
### Usage
```bash
-python tools/pytorch2onnx.py \
+python tools/deployment/pytorch2onnx.py \
    ${CONFIG_FILE} \
    --checkpoint ${CHECKPOINT_FILE} \
    --output-file ${OUTPUT_FILE} \
    --shape ${IMAGE_SHAPE} \
    --opset-version ${OPSET_VERSION} \
-    --dynamic-shape \
+    --dynamic-export \
    --show \
    --simplify \
    --verify \
```
-Description of all arguments:
+### Description of all arguments:
- `config` : The path of a model config file.
- `--checkpoint` : The path of a model checkpoint file.
- `--output-file`: The path of the output ONNX model. If not specified, it will be set to `tmp.onnx`.
- `--shape`: The height and width of the input tensor to the model. If not specified, it will be set to `224 224`.
- `--opset-version` : The opset version of ONNX. If not specified, it will be set to `11`.
-- `--dynamic-shape` : Determines whether to export ONNX with dynamic input shape. If not specified, it will be set to `False`.
+- `--dynamic-export` : Determines whether to export ONNX with dynamic input and output shapes. If not specified, it will be set to `False` (see the sketch after this list).
- `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`.
- `--simplify`: Determines whether to simplify the exported ONNX model. If not specified, it will be set to `False`.
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
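
Conceptually, `--dynamic-export` amounts to passing `dynamic_axes` to `torch.onnx.export`, as the `dynamic_axes` dict in the tool's diff below shows. A minimal standalone sketch; the torchvision model, output file name, and tensor names here are illustrative, not the tool's actual internals:

```python
import torch
import torchvision

# any eval-mode model works here; resnet18 is just a stand-in
model = torchvision.models.resnet18().eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    'tmp.onnx',
    opset_version=11,
    input_names=['input'],
    output_names=['probs'],
    # mark the batch and spatial dims as dynamic so other input sizes are legal
    dynamic_axes={
        'input': {0: 'batch', 2: 'height', 3: 'width'},
        'probs': {0: 'batch'}
    })
```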
@@ -53,28 +59,121 @@ Description of all arguments:
Example:
```bash
-python tools/pytorch2onnx.py \
+python tools/deployment/pytorch2onnx.py \
    configs/resnet/resnet18_b16x8_cifar10.py \
    --checkpoint checkpoints/resnet/resnet18_b16x8_cifar10.pth \
    --output-file checkpoints/resnet/resnet18_b16x8_cifar10.onnx \
-    --dynamic-shape \
+    --dynamic-export \
    --show \
    --simplify \
    --verify \
```
## How to evaluate ONNX models with ONNX Runtime
We provide `tools/deployment/test.py` to evaluate ONNX models with the ONNX Runtime backend.
### Prerequisite
- Install onnx and onnxruntime-gpu
```shell
pip install onnx onnxruntime-gpu
```
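
To sanity-check the installation, a quick probe of the runtime; these `onnxruntime` calls exist in current releases, and the printed values depend on your machine:

```python
import onnx
import onnxruntime as ort

print(onnx.__version__)
print(ort.get_device())               # 'GPU' with onnxruntime-gpu on a CUDA machine
print(ort.get_available_providers())  # e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']
```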
### Usage
```bash
python tools/deployment/test.py \
    ${CONFIG_FILE} \
    ${ONNX_FILE} \
    --out ${OUTPUT_FILE} \
    --metrics ${EVALUATION_METRICS} \
    --metric-options ${EVALUATION_OPTIONS} \
    --show \
    --show-dir ${SHOW_DIRECTORY} \
    --cfg-options ${CFG_OPTIONS} \
```
### Description of all arguments
- `config`: The path of a model config file.
- `model`: The path of an ONNX model file.
- `--out`: The path of the output result file in pickle format.
- `--metrics`: Evaluation metrics, which depend on the dataset, e.g., "accuracy", "precision", "recall", "f1_score", "support" for single-label datasets, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for multi-label datasets.
- `--show`: Determines whether to show classifier outputs. If not specified, it will be set to `False`.
- `--show-dir`: Directory where painted images will be saved.
- `--metric-options`: Custom options for evaluation. The key-value pairs in `xxx=yyy` format will be passed as kwargs to the `dataset.evaluate()` function (see the sketch after this list).
- `--cfg-options`: Override some settings in the used config file. The key-value pairs in `xxx=yyy` format will be merged into the config file.
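
For intuition about the `xxx=yyy` pairs above, `mmcv.DictAction` turns them into a plain dict that the tool forwards to `dataset.evaluate()`. A self-contained sketch; the `topk=1` value is just an example:

```python
import argparse

from mmcv import DictAction

parser = argparse.ArgumentParser()
parser.add_argument(
    '--metric-options', nargs='+', action=DictAction, default={})

# equivalent to passing `--metric-options topk=1` on the command line
args = parser.parse_args(['--metric-options', 'topk=1'])
print(args.metric_options)  # {'topk': 1}, forwarded to dataset.evaluate()
```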
### Results and Models
This section uses ImageNet for ONNX Runtime verification. ImageNet has multiple versions; the most commonly used one is [ILSVRC 2012](http://www.image-net.org/challenges/LSVRC/2012/).
<table>
<tr>
<th>Model</th>
<th>Config</th>
<th>Metric</th>
<th>PyTorch</th>
<th>ONNX Runtime</th>
</tr>
<tr>
<td>ResNet</td>
<td>resnet50_b32x8_imagenet.py</td>
<td>Top 1 / 5</td>
<td>76.55 / 93.15</td>
<td>76.49 / 93.22</td>
</tr>
<tr>
<td>ResNeXt</td>
<td>resnext50_32x4d_b32x8_imagenet.py</td>
<td>Top 1 / 5</td>
<td>77.83 / 93.65</td>
<td>77.83 / 93.65</td>
</tr>
<tr>
<td>SE-ResNet</td>
<td>seresnet50_b32x8_imagenet.py</td>
<td>Top 1 / 5</td>
<td>77.74 / 93.84</td>
<td>77.74 / 93.84</td>
</tr>
<tr>
<td>ShuffleNetV1</td>
<td>shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py</td>
<td>Top 1 / 5</td>
<td>68.13 / 87.81</td>
<td>68.13 / 87.81</td>
</tr>
<tr>
<td>ShuffleNetV2</td>
<td>shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py</td>
<td>Top 1 / 5</td>
<td>69.55 / 88.92</td>
<td>69.55 / 88.92</td>
</tr>
<tr>
<td>MobileNetV2</td>
<td>mobilenet_v2_b32x8_imagenet.py</td>
<td>Top 1 / 5</td>
<td>71.86 / 90.42</td>
<td>71.86 / 90.42</td>
</tr>
</table>
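
For reference, the Top-1/Top-5 numbers above can be recomputed from a raw score matrix like the one the tool collects. A toy sketch with random data:

```python
import numpy as np

# toy scores: 4 samples x 10 classes, plus ground-truth labels
rng = np.random.default_rng(0)
scores = rng.random((4, 10))
gt_labels = np.array([1, 3, 5, 7])

ranked = np.argsort(scores, axis=1)[:, ::-1]  # classes by descending score
top1 = (ranked[:, 0] == gt_labels).mean() * 100
top5 = np.any(ranked[:, :5] == gt_labels[:, None], axis=1).mean() * 100
print(f'Top-1: {top1:.2f}, Top-5: {top5:.2f}')
```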
## List of supported models exportable to ONNX
The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime.
| Model | Config | Batch Inference | Dynamic Shape | Note |
| :----------: | :----------------------------------------------------------: | :-------------: | :-----------: | ---- |
-| MobileNetV2 | `configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py` | Y | Y | |
-| ResNet | `configs/resnet/resnet18_b16x8_cifar10.py` | Y | Y | |
-| ResNeXt | `configs/resnext/resnext50_32x4d_b32x8_imagenet.py` | Y | Y | |
-| SE-ResNet | `configs/seresnet/seresnet50_b32x8_imagenet.py` | Y | Y | |
-| ShuffleNetV1 | `configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y | Y | |
-| ShuffleNetV2 | `configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y | Y | |
+| MobileNetV2 | [mobilenet_v2_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py) | Y | Y | |
+| ResNet | [resnet18_b16x8_cifar10.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet/resnet18_b16x8_cifar10.py) | Y | Y | |
+| ResNeXt | [resnext50_32x4d_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext/resnext50_32x4d_b32x8_imagenet.py) | Y | Y | |
+| SE-ResNet | [seresnet50_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet/seresnet50_b32x8_imagenet.py) | Y | Y | |
+| ShuffleNetV1 | [shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | |
+| ShuffleNetV2 | [shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | |
Notes:

mmcls/core/export/__init__.py (new file)

@@ -0,0 +1,3 @@
from .test import ONNXRuntimeClassifier

__all__ = ['ONNXRuntimeClassifier']

mmcls/core/export/test.py (new file)

@@ -0,0 +1,57 @@
import numpy as np
import onnxruntime as ort

from mmcls.models.classifiers import BaseClassifier


class ONNXRuntimeClassifier(BaseClassifier):
    """Wrapper for classifier's inference with ONNXRuntime."""

    def __init__(self, onnx_file, class_names, device_id):
        super(ONNXRuntimeClassifier, self).__init__()
        sess = ort.InferenceSession(onnx_file)

        # prefer CUDA over CPU when onnxruntime-gpu sees a device
        providers = ['CPUExecutionProvider']
        options = [{}]
        is_cuda_available = ort.get_device() == 'GPU'
        if is_cuda_available:
            providers.insert(0, 'CUDAExecutionProvider')
            options.insert(0, {'device_id': device_id})
        sess.set_providers(providers, options)

        self.sess = sess
        self.CLASSES = class_names
        self.device_id = device_id
        self.io_binding = sess.io_binding()
        self.output_names = [_.name for _ in sess.get_outputs()]
        self.is_cuda_available = is_cuda_available

    def simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def forward_train(self, imgs, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def forward_test(self, imgs, img_metas, **kwargs):
        input_data = imgs
        # set io binding for inputs/outputs
        device_type = 'cuda' if self.is_cuda_available else 'cpu'
        if not self.is_cuda_available:
            input_data = input_data.cpu()
        self.io_binding.bind_input(
            name='input',
            device_type=device_type,
            device_id=self.device_id,
            element_type=np.float32,
            shape=input_data.shape,
            buffer_ptr=input_data.data_ptr())
        for name in self.output_names:
            self.io_binding.bind_output(name)
        # run session to get outputs
        self.sess.run_with_iobinding(self.io_binding)
        results = self.io_binding.copy_outputs_to_cpu()[0]
        return list(results)
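
For reference, a standalone usage sketch of the wrapper above, assuming a CUDA device and an ONNX model whose input tensor is named `input`; the file path and class names are placeholders:

```python
import torch

from mmcls.core.export import ONNXRuntimeClassifier

# placeholder ONNX file and class names, for illustration only
model = ONNXRuntimeClassifier(
    'checkpoints/resnet/resnet18_b16x8_cifar10.onnx',
    class_names=['class_%d' % i for i in range(10)],
    device_id=0)

# io_binding reads the tensor's data_ptr, so the batch stays on the GPU
batch = torch.randn(1, 3, 224, 224).cuda()
scores = model.forward_test(batch, img_metas=None)
print(scores[0].shape)  # one score vector of length num_classes per sample
```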

tools/deployment/pytorch2onnx.py

@@ -37,7 +37,7 @@ def _demo_mm_inputs(input_shape, num_classes):
def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
-                 dynamic_shape=False,
+                 dynamic_export=False,
                 show=False,
                 output_file='tmp.onnx',
                 do_simplify=False,
@@ -70,7 +70,7 @@ def pytorch2onnx(model,
    register_extra_symbolics(opset_version)

    # support dynamic shape export
-    if dynamic_shape:
+    if dynamic_export:
        dynamic_axes = {
            'input': {
                0: 'batch',
@@ -121,7 +121,7 @@ def pytorch2onnx(model,
            output_file,
            input_shapes=input_shape_dic,
            input_data=input_dic,
-            dynamic_input_shape=dynamic_shape)
+            dynamic_input_shape=dynamic_export)
    if verify:
        # check by onnx
        import onnx
@@ -129,7 +129,7 @@ def pytorch2onnx(model,
        onnx.checker.check_model(onnx_model)

        # test the dynamic model
-        if dynamic_shape:
+        if dynamic_export:
            dynamic_test_inputs = _demo_mm_inputs(
                (input_shape[0], input_shape[1], input_shape[2] * 2,
                 input_shape[3] * 2), model.head.num_classes)
@@ -176,7 +176,7 @@ def parse_args():
        default=[224, 224],
        help='input image size')
    parser.add_argument(
-        '--dynamic-shape',
+        '--dynamic-export',
        action='store_true',
        help='Whether to export ONNX with dynamic input shape. \
        Defaults to False.')
@@ -212,7 +212,7 @@ if __name__ == '__main__':
        input_shape,
        opset_version=args.opset_version,
        show=args.show,
-        dynamic_shape=args.dynamic_shape,
+        dynamic_export=args.dynamic_export,
        output_file=args.output_file,
        do_simplify=args.simplify,
        verify=args.verify)
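
A quick way to confirm the renamed flag produced a truly dynamic model is to run the exported file at two input sizes; a sketch assuming a `tmp.onnx` exported with `--dynamic-export`:

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('tmp.onnx')  # hypothetical pytorch2onnx.py output
input_name = sess.get_inputs()[0].name

# both sizes should run only if the export marked the spatial dims as dynamic
for hw in (224, 448):
    x = np.random.rand(1, 3, hw, hw).astype(np.float32)
    out = sess.run(None, {input_name: x})[0]
    print(hw, out.shape)
```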

tools/deployment/test.py (new file)

@@ -0,0 +1,103 @@
import argparse
import warnings

import mmcv
import numpy as np
from mmcv import DictAction
from mmcv.parallel import MMDataParallel

from mmcls.apis import single_gpu_test
from mmcls.core.export import ONNXRuntimeClassifier
from mmcls.datasets import build_dataloader, build_dataset


def parse_args():
    parser = argparse.ArgumentParser(
        description='Test (and eval) an ONNX model using ONNXRuntime.')
    parser.add_argument('config', help='model config file')
    parser.add_argument('model', help='filename of the input ONNX model')
    parser.add_argument(
        '--out', type=str, help='output result file in pickle format')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file.')
    parser.add_argument(
        '--metrics',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., '
        '"accuracy", "precision", "recall", "f1_score", "support" for single '
        'label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for '
        'multi-label dataset')
    parser.add_argument(
        '--metric-options',
        nargs='+',
        action=DictAction,
        default={},
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be parsed as a dict metric_options for dataset.evaluate()'
        ' function.')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # build dataset and dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=cfg.data.samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        shuffle=False,
        round_up=False)

    # build onnxruntime model and run inference.
    model = ONNXRuntimeClassifier(
        args.model, class_names=dataset.CLASSES, device_id=0)
    model = MMDataParallel(model, device_ids=[0])
    outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)

    if args.metrics:
        results = dataset.evaluate(outputs, args.metrics, args.metric_options)
        for k, v in results.items():
            print(f'\n{k} : {v:.2f}')
    else:
        warnings.warn('Evaluation metrics are not specified.')
        scores = np.vstack(outputs)
        pred_score = np.max(scores, axis=1)
        pred_label = np.argmax(scores, axis=1)
        pred_class = [dataset.CLASSES[lb] for lb in pred_label]
        results = {
            'pred_score': pred_score,
            'pred_label': pred_label,
            'pred_class': pred_class
        }
        if not args.out:
            print('\nthe predicted result for the first element is '
                  f'pred_score = {pred_score[0]:.2f}, '
                  f'pred_label = {pred_label[0]} '
                  f'and pred_class = {pred_class[0]}. '
                  'Specify --out to save all results to files.')

    if args.out:
        print(f'\nwriting results to {args.out}')
        mmcv.dump(results, args.out)


if __name__ == '__main__':
    main()
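
Afterwards, the `--out` pickle can be inspected with `mmcv.load`; a short sketch, with an illustrative file name:

```python
import mmcv

# without --metrics the tool dumps a dict of raw predictions;
# with --metrics it dumps the dict returned by dataset.evaluate()
results = mmcv.load('results.pkl')
if 'pred_class' in results:
    print(results['pred_class'][:5], results['pred_score'][:5])
else:
    print(results)  # e.g. {'accuracy_top-1': ..., 'accuracy_top-5': ...}
```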