diff --git a/docs/tutorials/onnx2tensorrt.md b/docs/tutorials/onnx2tensorrt.md
index 0974cb991..27f744d15 100644
--- a/docs/tutorials/onnx2tensorrt.md
+++ b/docs/tutorials/onnx2tensorrt.md
@@ -43,7 +43,7 @@ Description of all arguments:
Example:
```bash
-python tools/onnx2tensorrt.py \
+python tools/deployment/onnx2tensorrt.py \
checkpoints/resnet/resnet18_b16x8_cifar10.onnx \
--trt-file checkpoints/resnet/resnet18_b16x8_cifar10.trt \
--shape 224 224 \
diff --git a/docs/tutorials/pytorch2onnx.md b/docs/tutorials/pytorch2onnx.md
index d1d48fcf5..c2315eea9 100644
--- a/docs/tutorials/pytorch2onnx.md
+++ b/docs/tutorials/pytorch2onnx.md
@@ -6,6 +6,12 @@
- [How to convert models from Pytorch to ONNX](#how-to-convert-models-from-pytorch-to-onnx)
- [Prerequisite](#prerequisite)
- [Usage](#usage)
+ - [Description of all arguments](#description-of-all-arguments)
+ - [How to evaluate ONNX models with ONNX Runtime](#how-to-evaluate-onnx-models-with-onnx-runtime)
+ - [Prerequisite](#prerequisite-1)
+ - [Usage](#usage-1)
+ - [Description of all arguments](#description-of-all-arguments-1)
+ - [Results and Models](#results-and-models)
- [List of supported models exportable to ONNX](#list-of-supported-models-exportable-to-onnx)
- [Reminders](#reminders)
- [FAQs](#faqs)
@@ -26,26 +32,26 @@
### Usage
```bash
-python tools/pytorch2onnx.py \
+python tools/deployment/pytorch2onnx.py \
${CONFIG_FILE} \
--checkpoint ${CHECKPOINT_FILE} \
--output-file ${OUTPUT_FILE} \
--shape ${IMAGE_SHAPE} \
--opset-version ${OPSET_VERSION} \
- --dynamic-shape \
+ --dynamic-export \
--show \
--simplify \
--verify \
```
-Description of all arguments:
+### Description of all arguments:
- `config` : The path of a model config file.
- `--checkpoint` : The path of a model checkpoint file.
- `--output-file`: The path of output ONNX model. If not specified, it will be set to `tmp.onnx`.
- `--shape`: The height and width of input tensor to the model. If not specified, it will be set to `224 224`.
- `--opset-version` : The opset version of ONNX. If not specified, it will be set to `11`.
-- `--dynamic-shape` : Determines whether to export ONNX with dynamic input shape. If not specified, it will be set to `False`.
+- `--dynamic-export` : Determines whether to export ONNX with dynamic input shape and output shapes. If not specified, it will be set to `False`.
- `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`.
- `--simplify`: Determines whether to simplify the exported ONNX model. If not specified, it will be set to `False`.
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
@@ -53,28 +59,121 @@ Description of all arguments:
Example:
```bash
-python tools/pytorch2onnx.py \
+python tools/deployment/pytorch2onnx.py \
configs/resnet/resnet18_b16x8_cifar10.py \
--checkpoint checkpoints/resnet/resnet18_b16x8_cifar10.pth \
--output-file checkpoints/resnet/resnet18_b16x8_cifar10.onnx \
- --dynamic-shape \
+ --dynamic-export \
--show \
--simplify \
--verify \
```
+## How to evaluate ONNX models with ONNX Runtime
+
+We prepare a tool `tools/deployment/test.py` to evaluate ONNX models with ONNX Runtime backend.
+
+### Prerequisite
+
+- Install onnx and onnxruntime-gpu
+
+ ```shell
+ pip install onnx onnxruntime-gpu
+ ```
+
+### Usage
+
+```bash
+python tools/deployment/test.py \
+ ${CONFIG_FILE} \
+ ${ONNX_FILE} \
+ --out ${OUTPUT_FILE} \
+ --metrics ${EVALUATION_METRICS} \
+ --metric-options ${EVALUATION_OPTIONS} \
+ --show
+ --show-dir ${SHOW_DIRECTORY} \
+ --cfg-options ${CFG_OPTIONS} \
+```
+
+### Description of all arguments
+
+- `config`: The path of a model config file.
+- `model`: The path of a ONNX model file.
+- `--out`: The path of output result file in pickle format.
+- `--metrics`: Evaluation metrics, which depends on the dataset, e.g., "accuracy", "precision", "recall", "f1_score", "support" for single label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for multi-label dataset.
+- `--show`: Determines whether to show classifier outputs. If not specified, it will be set to `False`.
+- `--show-dir`: Directory where painted images will be saved
+- `--metrics-options`: Custom options for evaluation, the key-value pair in `xxx=yyy` format will be kwargs for `dataset.evaluate()` function
+- `--cfg-options`: Override some settings in the used config file, the key-value pair in `xxx=yyy` format will be merged into config file.
+
+### Results and Models
+
+This part selects ImageNet for onnxruntime verification. ImageNet has multiple versions, but the most commonly used one is [ILSVRC 2012](http://www.image-net.org/challenges/LSVRC/2012/).
+
+
+
+Model |
+Config |
+Metric |
+PyTorch |
+ONNX Runtime |
+
+
+ResNet |
+resnet50_b32x8_imagenet.py |
+Top 1 / 5 |
+76.55 / 93.15 |
+76.49 / 93.22 |
+
+
+ResNeXt |
+resnext50_32x4d_b32x8_imagenet.py |
+Top 1 / 5 |
+77.83 / 93.65 |
+77.83 / 93.65 |
+
+
+SE-ResNet |
+seresnet50_b32x8_imagenet.py |
+Top 1 / 5 |
+77.74 / 93.84 |
+77.74 / 93.84 |
+
+
+ShuffleNetV1 |
+shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py |
+Top 1 / 5 |
+68.13 / 87.81 |
+68.13 / 87.81 |
+
+
+ShuffleNetV2 |
+shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py |
+Top 1 / 5 |
+69.55 / 88.92 |
+69.55 / 88.92 |
+
+
+MobileNetV2 |
+mobilenet_v2_b32x8_imagenet.py |
+Top 1 / 5 |
+71.86 / 90.42 |
+71.86 / 90.42 |
+
+
+
## List of supported models exportable to ONNX
The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime.
| Model | Config | Batch Inference | Dynamic Shape | Note |
| :----------: | :----------------------------------------------------------: | :-------------: | :-----------: | ---- |
-| MobileNetV2 | `configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py` | Y | Y | |
-| ResNet | `configs/resnet/resnet18_b16x8_cifar10.py` | Y | Y | |
-| ResNeXt | `configs/resnext/resnext50_32x4d_b32x8_imagenet.py` | Y | Y | |
-| SE-ResNet | `configs/seresnet/seresnet50_b32x8_imagenet.py` | Y | Y | |
-| ShuffleNetV1 | `configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y | Y | |
-| ShuffleNetV2 | `configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y | Y | |
+| MobileNetV2 | [mobilenet_v2_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py) | Y | Y | |
+| ResNet | [resnet18_b16x8_cifar10.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet/resnet18_b16x8_cifar10.py) | Y | Y | |
+| ResNeXt | [resnext50_32x4d_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext/resnext50_32x4d_b32x8_imagenet.py) | Y | Y | |
+| SE-ResNet | [seresnet50_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet/seresnet50_b32x8_imagenet.py) | Y | Y | |
+| ShuffleNetV1 | [shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | |
+| ShuffleNetV2 | [shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | |
Notes:
diff --git a/mmcls/core/export/__init__.py b/mmcls/core/export/__init__.py
new file mode 100644
index 000000000..955e72632
--- /dev/null
+++ b/mmcls/core/export/__init__.py
@@ -0,0 +1,3 @@
+from .test import ONNXRuntimeClassifier
+
+__all__ = ['ONNXRuntimeClassifier']
diff --git a/mmcls/core/export/test.py b/mmcls/core/export/test.py
new file mode 100644
index 000000000..11f44abf8
--- /dev/null
+++ b/mmcls/core/export/test.py
@@ -0,0 +1,57 @@
+import numpy as np
+import onnxruntime as ort
+
+from mmcls.models.classifiers import BaseClassifier
+
+
+class ONNXRuntimeClassifier(BaseClassifier):
+ """Wrapper for classifier's inference with ONNXRuntime."""
+
+ def __init__(self, onnx_file, class_names, device_id):
+ super(ONNXRuntimeClassifier, self).__init__()
+ sess = ort.InferenceSession(onnx_file)
+
+ providers = ['CPUExecutionProvider']
+ options = [{}]
+ is_cuda_available = ort.get_device() == 'GPU'
+ if is_cuda_available:
+ providers.insert(0, 'CUDAExecutionProvider')
+ options.insert(0, {'device_id': device_id})
+ sess.set_providers(providers, options)
+
+ self.sess = sess
+ self.CLASSES = class_names
+ self.device_id = device_id
+ self.io_binding = sess.io_binding()
+ self.output_names = [_.name for _ in sess.get_outputs()]
+ self.is_cuda_available = is_cuda_available
+
+ def simple_test(self, img, img_metas, **kwargs):
+ raise NotImplementedError('This method is not implemented.')
+
+ def extract_feat(self, imgs):
+ raise NotImplementedError('This method is not implemented.')
+
+ def forward_train(self, imgs, **kwargs):
+ raise NotImplementedError('This method is not implemented.')
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ input_data = imgs
+ # set io binding for inputs/outputs
+ device_type = 'cuda' if self.is_cuda_available else 'cpu'
+ if not self.is_cuda_available:
+ input_data = input_data.cpu()
+ self.io_binding.bind_input(
+ name='input',
+ device_type=device_type,
+ device_id=self.device_id,
+ element_type=np.float32,
+ shape=input_data.shape,
+ buffer_ptr=input_data.data_ptr())
+
+ for name in self.output_names:
+ self.io_binding.bind_output(name)
+ # run session to get outputs
+ self.sess.run_with_iobinding(self.io_binding)
+ results = self.io_binding.copy_outputs_to_cpu()[0]
+ return list(results)
diff --git a/tools/onnx2tensorrt.py b/tools/deployment/onnx2tensorrt.py
similarity index 100%
rename from tools/onnx2tensorrt.py
rename to tools/deployment/onnx2tensorrt.py
diff --git a/tools/pytorch2onnx.py b/tools/deployment/pytorch2onnx.py
similarity index 97%
rename from tools/pytorch2onnx.py
rename to tools/deployment/pytorch2onnx.py
index 3303066cf..98e7078c1 100644
--- a/tools/pytorch2onnx.py
+++ b/tools/deployment/pytorch2onnx.py
@@ -37,7 +37,7 @@ def _demo_mm_inputs(input_shape, num_classes):
def pytorch2onnx(model,
input_shape,
opset_version=11,
- dynamic_shape=False,
+ dynamic_export=False,
show=False,
output_file='tmp.onnx',
do_simplify=False,
@@ -70,7 +70,7 @@ def pytorch2onnx(model,
register_extra_symbolics(opset_version)
# support dynamic shape export
- if dynamic_shape:
+ if dynamic_export:
dynamic_axes = {
'input': {
0: 'batch',
@@ -121,7 +121,7 @@ def pytorch2onnx(model,
output_file,
input_shapes=input_shape_dic,
input_data=input_dic,
- dynamic_input_shape=dynamic_shape)
+ dynamic_input_shape=dynamic_export)
if verify:
# check by onnx
import onnx
@@ -129,7 +129,7 @@ def pytorch2onnx(model,
onnx.checker.check_model(onnx_model)
# test the dynamic model
- if dynamic_shape:
+ if dynamic_export:
dynamic_test_inputs = _demo_mm_inputs(
(input_shape[0], input_shape[1], input_shape[2] * 2,
input_shape[3] * 2), model.head.num_classes)
@@ -176,7 +176,7 @@ def parse_args():
default=[224, 224],
help='input image size')
parser.add_argument(
- '--dynamic-shape',
+ '--dynamic-export',
action='store_true',
help='Whether to export ONNX with dynamic input shape. \
Defaults to False.')
@@ -212,7 +212,7 @@ if __name__ == '__main__':
input_shape,
opset_version=args.opset_version,
show=args.show,
- dynamic_shape=args.dynamic_shape,
+ dynamic_export=args.dynamic_export,
output_file=args.output_file,
do_simplify=args.simplify,
verify=args.verify)
diff --git a/tools/deployment/test.py b/tools/deployment/test.py
new file mode 100644
index 000000000..6d0e5ad05
--- /dev/null
+++ b/tools/deployment/test.py
@@ -0,0 +1,103 @@
+import argparse
+import warnings
+
+import mmcv
+import numpy as np
+from mmcv import DictAction
+from mmcv.parallel import MMDataParallel
+
+from mmcls.apis import single_gpu_test
+from mmcls.core.export import ONNXRuntimeClassifier
+from mmcls.datasets import build_dataloader, build_dataset
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='Test (and eval) an ONNX model using ONNXRuntime.')
+ parser.add_argument('config', help='model config file')
+ parser.add_argument('model', help='filename of the input ONNX model')
+ parser.add_argument(
+ '--out', type=str, help='output result file in pickle format')
+ parser.add_argument(
+ '--cfg-options',
+ nargs='+',
+ action=DictAction,
+ help='override some settings in the used config, the key-value pair '
+ 'in xxx=yyy format will be merged into config file.')
+ parser.add_argument(
+ '--metrics',
+ type=str,
+ nargs='+',
+ help='evaluation metrics, which depends on the dataset, e.g., '
+ '"accuracy", "precision", "recall", "f1_score", "support" for single '
+ 'label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for '
+ 'multi-label dataset')
+ parser.add_argument(
+ '--metric-options',
+ nargs='+',
+ action=DictAction,
+ default={},
+ help='custom options for evaluation, the key-value pair in xxx=yyy '
+ 'format will be parsed as a dict metric_options for dataset.evaluate()'
+ ' function.')
+ parser.add_argument('--show', action='store_true', help='show results')
+ parser.add_argument(
+ '--show-dir', help='directory where painted images will be saved')
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
+ raise ValueError('The output file must be a pkl file.')
+
+ cfg = mmcv.Config.fromfile(args.config)
+ if args.cfg_options is not None:
+ cfg.merge_from_dict(args.cfg_options)
+
+ # build dataset and dataloader
+ dataset = build_dataset(cfg.data.test)
+ data_loader = build_dataloader(
+ dataset,
+ samples_per_gpu=cfg.data.samples_per_gpu,
+ workers_per_gpu=cfg.data.workers_per_gpu,
+ shuffle=False,
+ round_up=False)
+
+ # build onnxruntime model and run inference.
+ model = ONNXRuntimeClassifier(
+ args.model, class_names=dataset.CLASSES, device_id=0)
+
+ model = MMDataParallel(model, device_ids=[0])
+ outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
+
+ if args.metrics:
+ results = dataset.evaluate(outputs, args.metrics, args.metric_options)
+ for k, v in results.items():
+ print(f'\n{k} : {v:.2f}')
+ else:
+ warnings.warn('Evaluation metrics are not specified.')
+ scores = np.vstack(outputs)
+ pred_score = np.max(scores, axis=1)
+ pred_label = np.argmax(scores, axis=1)
+ pred_class = [dataset.CLASSES[lb] for lb in pred_label]
+ results = {
+ 'pred_score': pred_score,
+ 'pred_label': pred_label,
+ 'pred_class': pred_class
+ }
+ if not args.out:
+ print('\nthe predicted result for the first element is '
+ f'pred_score = {pred_score[0]:.2f}, '
+ f'pred_label = {pred_label[0]} '
+ f'and pred_class = {pred_class[0]}. '
+ 'Specify --out to save all results to files.')
+ if args.out:
+ print(f'\nwriting results to {args.out}')
+ mmcv.dump(results, args.out)
+
+
+if __name__ == '__main__':
+ main()