mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Feature]: add onnxruntime test tool (#212)
* [draft] add onnxruntime accuruacy verification * fix a bug * update code * fix lint * fix lint * update code and doc * update doc * update code * update code * updata doc and updata code * update doc and fix some bug * update doc * update doc * update doc * update doc * update doc * update doc * fix bug * update doc * update code * move CUDAExecutionProvider to first place * update resnext accuracy * update doc Co-authored-by: maningsheng <maningsheng@sensetime.com>
This commit is contained in:
parent
9be435846c
commit
01d2849b76
@ -43,7 +43,7 @@ Description of all arguments:
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python tools/onnx2tensorrt.py \
|
||||
python tools/deployment/onnx2tensorrt.py \
|
||||
checkpoints/resnet/resnet18_b16x8_cifar10.onnx \
|
||||
--trt-file checkpoints/resnet/resnet18_b16x8_cifar10.trt \
|
||||
--shape 224 224 \
|
||||
|
@ -6,6 +6,12 @@
|
||||
- [How to convert models from Pytorch to ONNX](#how-to-convert-models-from-pytorch-to-onnx)
|
||||
- [Prerequisite](#prerequisite)
|
||||
- [Usage](#usage)
|
||||
- [Description of all arguments](#description-of-all-arguments)
|
||||
- [How to evaluate ONNX models with ONNX Runtime](#how-to-evaluate-onnx-models-with-onnx-runtime)
|
||||
- [Prerequisite](#prerequisite-1)
|
||||
- [Usage](#usage-1)
|
||||
- [Description of all arguments](#description-of-all-arguments-1)
|
||||
- [Results and Models](#results-and-models)
|
||||
- [List of supported models exportable to ONNX](#list-of-supported-models-exportable-to-onnx)
|
||||
- [Reminders](#reminders)
|
||||
- [FAQs](#faqs)
|
||||
@ -26,26 +32,26 @@
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
python tools/pytorch2onnx.py \
|
||||
python tools/deployment/pytorch2onnx.py \
|
||||
${CONFIG_FILE} \
|
||||
--checkpoint ${CHECKPOINT_FILE} \
|
||||
--output-file ${OUTPUT_FILE} \
|
||||
--shape ${IMAGE_SHAPE} \
|
||||
--opset-version ${OPSET_VERSION} \
|
||||
--dynamic-shape \
|
||||
--dynamic-export \
|
||||
--show \
|
||||
--simplify \
|
||||
--verify \
|
||||
```
|
||||
|
||||
Description of all arguments:
|
||||
### Description of all arguments:
|
||||
|
||||
- `config` : The path of a model config file.
|
||||
- `--checkpoint` : The path of a model checkpoint file.
|
||||
- `--output-file`: The path of output ONNX model. If not specified, it will be set to `tmp.onnx`.
|
||||
- `--shape`: The height and width of input tensor to the model. If not specified, it will be set to `224 224`.
|
||||
- `--opset-version` : The opset version of ONNX. If not specified, it will be set to `11`.
|
||||
- `--dynamic-shape` : Determines whether to export ONNX with dynamic input shape. If not specified, it will be set to `False`.
|
||||
- `--dynamic-export` : Determines whether to export ONNX with dynamic input shape and output shapes. If not specified, it will be set to `False`.
|
||||
- `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`.
|
||||
- `--simplify`: Determines whether to simplify the exported ONNX model. If not specified, it will be set to `False`.
|
||||
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
|
||||
@ -53,28 +59,121 @@ Description of all arguments:
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python tools/pytorch2onnx.py \
|
||||
python tools/deployment/pytorch2onnx.py \
|
||||
configs/resnet/resnet18_b16x8_cifar10.py \
|
||||
--checkpoint checkpoints/resnet/resnet18_b16x8_cifar10.pth \
|
||||
--output-file checkpoints/resnet/resnet18_b16x8_cifar10.onnx \
|
||||
--dynamic-shape \
|
||||
--dynamic-export \
|
||||
--show \
|
||||
--simplify \
|
||||
--verify \
|
||||
```
|
||||
|
||||
## How to evaluate ONNX models with ONNX Runtime
|
||||
|
||||
We prepare a tool `tools/deployment/test.py` to evaluate ONNX models with ONNX Runtime backend.
|
||||
|
||||
### Prerequisite
|
||||
|
||||
- Install onnx and onnxruntime-gpu
|
||||
|
||||
```shell
|
||||
pip install onnx onnxruntime-gpu
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
python tools/deployment/test.py \
|
||||
${CONFIG_FILE} \
|
||||
${ONNX_FILE} \
|
||||
--out ${OUTPUT_FILE} \
|
||||
--metrics ${EVALUATION_METRICS} \
|
||||
--metric-options ${EVALUATION_OPTIONS} \
|
||||
--show
|
||||
--show-dir ${SHOW_DIRECTORY} \
|
||||
--cfg-options ${CFG_OPTIONS} \
|
||||
```
|
||||
|
||||
### Description of all arguments
|
||||
|
||||
- `config`: The path of a model config file.
|
||||
- `model`: The path of a ONNX model file.
|
||||
- `--out`: The path of output result file in pickle format.
|
||||
- `--metrics`: Evaluation metrics, which depends on the dataset, e.g., "accuracy", "precision", "recall", "f1_score", "support" for single label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for multi-label dataset.
|
||||
- `--show`: Determines whether to show classifier outputs. If not specified, it will be set to `False`.
|
||||
- `--show-dir`: Directory where painted images will be saved
|
||||
- `--metrics-options`: Custom options for evaluation, the key-value pair in `xxx=yyy` format will be kwargs for `dataset.evaluate()` function
|
||||
- `--cfg-options`: Override some settings in the used config file, the key-value pair in `xxx=yyy` format will be merged into config file.
|
||||
|
||||
### Results and Models
|
||||
|
||||
This part selects ImageNet for onnxruntime verification. ImageNet has multiple versions, but the most commonly used one is [ILSVRC 2012](http://www.image-net.org/challenges/LSVRC/2012/).
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th>Model</th>
|
||||
<th>Config</th>
|
||||
<th>Metric</th>
|
||||
<th>PyTorch</th>
|
||||
<th>ONNX Runtime</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>ResNet</td>
|
||||
<td>resnet50_b32x8_imagenet.py</td>
|
||||
<td>Top 1 / 5</td>
|
||||
<td>76.55 / 93.15</td>
|
||||
<td>76.49 / 93.22</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>ResNeXt</td>
|
||||
<td>resnext50_32x4d_b32x8_imagenet.py</td>
|
||||
<td>Top 1 / 5</td>
|
||||
<td>77.83 / 93.65</td>
|
||||
<td>77.83 / 93.65</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>SE-ResNet</td>
|
||||
<td>seresnet50_b32x8_imagenet.py</td>
|
||||
<td>Top 1 / 5</td>
|
||||
<td>77.74 / 93.84</td>
|
||||
<td>77.74 / 93.84</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>ShuffleNetV1</td>
|
||||
<td>shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py</td>
|
||||
<td>Top 1 / 5</td>
|
||||
<td>68.13 / 87.81</td>
|
||||
<td>68.13 / 87.81</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>ShuffleNetV2</td>
|
||||
<td>shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py</td>
|
||||
<td>Top 1 / 5</td>
|
||||
<td>69.55 / 88.92</td>
|
||||
<td>69.55 / 88.92</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>MobileNetV2</td>
|
||||
<td>mobilenet_v2_b32x8_imagenet.py</td>
|
||||
<td>Top 1 / 5</td>
|
||||
<td>71.86 / 90.42</td>
|
||||
<td>71.86 / 90.42</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## List of supported models exportable to ONNX
|
||||
|
||||
The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime.
|
||||
|
||||
| Model | Config | Batch Inference | Dynamic Shape | Note |
|
||||
| :----------: | :----------------------------------------------------------: | :-------------: | :-----------: | ---- |
|
||||
| MobileNetV2 | `configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py` | Y | Y | |
|
||||
| ResNet | `configs/resnet/resnet18_b16x8_cifar10.py` | Y | Y | |
|
||||
| ResNeXt | `configs/resnext/resnext50_32x4d_b32x8_imagenet.py` | Y | Y | |
|
||||
| SE-ResNet | `configs/seresnet/seresnet50_b32x8_imagenet.py` | Y | Y | |
|
||||
| ShuffleNetV1 | `configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y | Y | |
|
||||
| ShuffleNetV2 | `configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y | Y | |
|
||||
| MobileNetV2 | [mobilenet_v2_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py) | Y | Y | |
|
||||
| ResNet | [resnet18_b16x8_cifar10.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet/resnet18_b16x8_cifar10.py) | Y | Y | |
|
||||
| ResNeXt | [resnext50_32x4d_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext/resnext50_32x4d_b32x8_imagenet.py) | Y | Y | |
|
||||
| SE-ResNet | [seresnet50_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet/seresnet50_b32x8_imagenet.py) | Y | Y | |
|
||||
| ShuffleNetV1 | [shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | |
|
||||
| ShuffleNetV2 | [shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | |
|
||||
|
||||
Notes:
|
||||
|
||||
|
3
mmcls/core/export/__init__.py
Normal file
3
mmcls/core/export/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .test import ONNXRuntimeClassifier
|
||||
|
||||
__all__ = ['ONNXRuntimeClassifier']
|
57
mmcls/core/export/test.py
Normal file
57
mmcls/core/export/test.py
Normal file
@ -0,0 +1,57 @@
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from mmcls.models.classifiers import BaseClassifier
|
||||
|
||||
|
||||
class ONNXRuntimeClassifier(BaseClassifier):
|
||||
"""Wrapper for classifier's inference with ONNXRuntime."""
|
||||
|
||||
def __init__(self, onnx_file, class_names, device_id):
|
||||
super(ONNXRuntimeClassifier, self).__init__()
|
||||
sess = ort.InferenceSession(onnx_file)
|
||||
|
||||
providers = ['CPUExecutionProvider']
|
||||
options = [{}]
|
||||
is_cuda_available = ort.get_device() == 'GPU'
|
||||
if is_cuda_available:
|
||||
providers.insert(0, 'CUDAExecutionProvider')
|
||||
options.insert(0, {'device_id': device_id})
|
||||
sess.set_providers(providers, options)
|
||||
|
||||
self.sess = sess
|
||||
self.CLASSES = class_names
|
||||
self.device_id = device_id
|
||||
self.io_binding = sess.io_binding()
|
||||
self.output_names = [_.name for _ in sess.get_outputs()]
|
||||
self.is_cuda_available = is_cuda_available
|
||||
|
||||
def simple_test(self, img, img_metas, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def extract_feat(self, imgs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def forward_train(self, imgs, **kwargs):
|
||||
raise NotImplementedError('This method is not implemented.')
|
||||
|
||||
def forward_test(self, imgs, img_metas, **kwargs):
|
||||
input_data = imgs
|
||||
# set io binding for inputs/outputs
|
||||
device_type = 'cuda' if self.is_cuda_available else 'cpu'
|
||||
if not self.is_cuda_available:
|
||||
input_data = input_data.cpu()
|
||||
self.io_binding.bind_input(
|
||||
name='input',
|
||||
device_type=device_type,
|
||||
device_id=self.device_id,
|
||||
element_type=np.float32,
|
||||
shape=input_data.shape,
|
||||
buffer_ptr=input_data.data_ptr())
|
||||
|
||||
for name in self.output_names:
|
||||
self.io_binding.bind_output(name)
|
||||
# run session to get outputs
|
||||
self.sess.run_with_iobinding(self.io_binding)
|
||||
results = self.io_binding.copy_outputs_to_cpu()[0]
|
||||
return list(results)
|
@ -37,7 +37,7 @@ def _demo_mm_inputs(input_shape, num_classes):
|
||||
def pytorch2onnx(model,
|
||||
input_shape,
|
||||
opset_version=11,
|
||||
dynamic_shape=False,
|
||||
dynamic_export=False,
|
||||
show=False,
|
||||
output_file='tmp.onnx',
|
||||
do_simplify=False,
|
||||
@ -70,7 +70,7 @@ def pytorch2onnx(model,
|
||||
register_extra_symbolics(opset_version)
|
||||
|
||||
# support dynamic shape export
|
||||
if dynamic_shape:
|
||||
if dynamic_export:
|
||||
dynamic_axes = {
|
||||
'input': {
|
||||
0: 'batch',
|
||||
@ -121,7 +121,7 @@ def pytorch2onnx(model,
|
||||
output_file,
|
||||
input_shapes=input_shape_dic,
|
||||
input_data=input_dic,
|
||||
dynamic_input_shape=dynamic_shape)
|
||||
dynamic_input_shape=dynamic_export)
|
||||
if verify:
|
||||
# check by onnx
|
||||
import onnx
|
||||
@ -129,7 +129,7 @@ def pytorch2onnx(model,
|
||||
onnx.checker.check_model(onnx_model)
|
||||
|
||||
# test the dynamic model
|
||||
if dynamic_shape:
|
||||
if dynamic_export:
|
||||
dynamic_test_inputs = _demo_mm_inputs(
|
||||
(input_shape[0], input_shape[1], input_shape[2] * 2,
|
||||
input_shape[3] * 2), model.head.num_classes)
|
||||
@ -176,7 +176,7 @@ def parse_args():
|
||||
default=[224, 224],
|
||||
help='input image size')
|
||||
parser.add_argument(
|
||||
'--dynamic-shape',
|
||||
'--dynamic-export',
|
||||
action='store_true',
|
||||
help='Whether to export ONNX with dynamic input shape. \
|
||||
Defaults to False.')
|
||||
@ -212,7 +212,7 @@ if __name__ == '__main__':
|
||||
input_shape,
|
||||
opset_version=args.opset_version,
|
||||
show=args.show,
|
||||
dynamic_shape=args.dynamic_shape,
|
||||
dynamic_export=args.dynamic_export,
|
||||
output_file=args.output_file,
|
||||
do_simplify=args.simplify,
|
||||
verify=args.verify)
|
103
tools/deployment/test.py
Normal file
103
tools/deployment/test.py
Normal file
@ -0,0 +1,103 @@
|
||||
import argparse
|
||||
import warnings
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv import DictAction
|
||||
from mmcv.parallel import MMDataParallel
|
||||
|
||||
from mmcls.apis import single_gpu_test
|
||||
from mmcls.core.export import ONNXRuntimeClassifier
|
||||
from mmcls.datasets import build_dataloader, build_dataset
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Test (and eval) an ONNX model using ONNXRuntime.')
|
||||
parser.add_argument('config', help='model config file')
|
||||
parser.add_argument('model', help='filename of the input ONNX model')
|
||||
parser.add_argument(
|
||||
'--out', type=str, help='output result file in pickle format')
|
||||
parser.add_argument(
|
||||
'--cfg-options',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
help='override some settings in the used config, the key-value pair '
|
||||
'in xxx=yyy format will be merged into config file.')
|
||||
parser.add_argument(
|
||||
'--metrics',
|
||||
type=str,
|
||||
nargs='+',
|
||||
help='evaluation metrics, which depends on the dataset, e.g., '
|
||||
'"accuracy", "precision", "recall", "f1_score", "support" for single '
|
||||
'label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for '
|
||||
'multi-label dataset')
|
||||
parser.add_argument(
|
||||
'--metric-options',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
default={},
|
||||
help='custom options for evaluation, the key-value pair in xxx=yyy '
|
||||
'format will be parsed as a dict metric_options for dataset.evaluate()'
|
||||
' function.')
|
||||
parser.add_argument('--show', action='store_true', help='show results')
|
||||
parser.add_argument(
|
||||
'--show-dir', help='directory where painted images will be saved')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
|
||||
raise ValueError('The output file must be a pkl file.')
|
||||
|
||||
cfg = mmcv.Config.fromfile(args.config)
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
# build dataset and dataloader
|
||||
dataset = build_dataset(cfg.data.test)
|
||||
data_loader = build_dataloader(
|
||||
dataset,
|
||||
samples_per_gpu=cfg.data.samples_per_gpu,
|
||||
workers_per_gpu=cfg.data.workers_per_gpu,
|
||||
shuffle=False,
|
||||
round_up=False)
|
||||
|
||||
# build onnxruntime model and run inference.
|
||||
model = ONNXRuntimeClassifier(
|
||||
args.model, class_names=dataset.CLASSES, device_id=0)
|
||||
|
||||
model = MMDataParallel(model, device_ids=[0])
|
||||
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
|
||||
|
||||
if args.metrics:
|
||||
results = dataset.evaluate(outputs, args.metrics, args.metric_options)
|
||||
for k, v in results.items():
|
||||
print(f'\n{k} : {v:.2f}')
|
||||
else:
|
||||
warnings.warn('Evaluation metrics are not specified.')
|
||||
scores = np.vstack(outputs)
|
||||
pred_score = np.max(scores, axis=1)
|
||||
pred_label = np.argmax(scores, axis=1)
|
||||
pred_class = [dataset.CLASSES[lb] for lb in pred_label]
|
||||
results = {
|
||||
'pred_score': pred_score,
|
||||
'pred_label': pred_label,
|
||||
'pred_class': pred_class
|
||||
}
|
||||
if not args.out:
|
||||
print('\nthe predicted result for the first element is '
|
||||
f'pred_score = {pred_score[0]:.2f}, '
|
||||
f'pred_label = {pred_label[0]} '
|
||||
f'and pred_class = {pred_class[0]}. '
|
||||
'Specify --out to save all results to files.')
|
||||
if args.out:
|
||||
print(f'\nwriting results to {args.out}')
|
||||
mmcv.dump(results, args.out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user