diff --git a/docs/tutorials/onnx2tensorrt.md b/docs/tutorials/onnx2tensorrt.md index 0974cb991..27f744d15 100644 --- a/docs/tutorials/onnx2tensorrt.md +++ b/docs/tutorials/onnx2tensorrt.md @@ -43,7 +43,7 @@ Description of all arguments: Example: ```bash -python tools/onnx2tensorrt.py \ +python tools/deployment/onnx2tensorrt.py \ checkpoints/resnet/resnet18_b16x8_cifar10.onnx \ --trt-file checkpoints/resnet/resnet18_b16x8_cifar10.trt \ --shape 224 224 \ diff --git a/docs/tutorials/pytorch2onnx.md b/docs/tutorials/pytorch2onnx.md index d1d48fcf5..c2315eea9 100644 --- a/docs/tutorials/pytorch2onnx.md +++ b/docs/tutorials/pytorch2onnx.md @@ -6,6 +6,12 @@ - [How to convert models from Pytorch to ONNX](#how-to-convert-models-from-pytorch-to-onnx) - [Prerequisite](#prerequisite) - [Usage](#usage) + - [Description of all arguments](#description-of-all-arguments) + - [How to evaluate ONNX models with ONNX Runtime](#how-to-evaluate-onnx-models-with-onnx-runtime) + - [Prerequisite](#prerequisite-1) + - [Usage](#usage-1) + - [Description of all arguments](#description-of-all-arguments-1) + - [Results and Models](#results-and-models) - [List of supported models exportable to ONNX](#list-of-supported-models-exportable-to-onnx) - [Reminders](#reminders) - [FAQs](#faqs) @@ -26,26 +32,26 @@ ### Usage ```bash -python tools/pytorch2onnx.py \ +python tools/deployment/pytorch2onnx.py \ ${CONFIG_FILE} \ --checkpoint ${CHECKPOINT_FILE} \ --output-file ${OUTPUT_FILE} \ --shape ${IMAGE_SHAPE} \ --opset-version ${OPSET_VERSION} \ - --dynamic-shape \ + --dynamic-export \ --show \ --simplify \ --verify \ ``` -Description of all arguments: +### Description of all arguments: - `config` : The path of a model config file. - `--checkpoint` : The path of a model checkpoint file. - `--output-file`: The path of output ONNX model. If not specified, it will be set to `tmp.onnx`. - `--shape`: The height and width of input tensor to the model. If not specified, it will be set to `224 224`. - `--opset-version` : The opset version of ONNX. If not specified, it will be set to `11`. -- `--dynamic-shape` : Determines whether to export ONNX with dynamic input shape. If not specified, it will be set to `False`. +- `--dynamic-export` : Determines whether to export ONNX with dynamic input shape and output shapes. If not specified, it will be set to `False`. - `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`. - `--simplify`: Determines whether to simplify the exported ONNX model. If not specified, it will be set to `False`. - `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`. @@ -53,28 +59,121 @@ Description of all arguments: Example: ```bash -python tools/pytorch2onnx.py \ +python tools/deployment/pytorch2onnx.py \ configs/resnet/resnet18_b16x8_cifar10.py \ --checkpoint checkpoints/resnet/resnet18_b16x8_cifar10.pth \ --output-file checkpoints/resnet/resnet18_b16x8_cifar10.onnx \ - --dynamic-shape \ + --dynamic-export \ --show \ --simplify \ --verify \ ``` +## How to evaluate ONNX models with ONNX Runtime + +We prepare a tool `tools/deployment/test.py` to evaluate ONNX models with ONNX Runtime backend. + +### Prerequisite + +- Install onnx and onnxruntime-gpu + + ```shell + pip install onnx onnxruntime-gpu + ``` + +### Usage + +```bash +python tools/deployment/test.py \ + ${CONFIG_FILE} \ + ${ONNX_FILE} \ + --out ${OUTPUT_FILE} \ + --metrics ${EVALUATION_METRICS} \ + --metric-options ${EVALUATION_OPTIONS} \ + --show + --show-dir ${SHOW_DIRECTORY} \ + --cfg-options ${CFG_OPTIONS} \ +``` + +### Description of all arguments + +- `config`: The path of a model config file. +- `model`: The path of a ONNX model file. +- `--out`: The path of output result file in pickle format. +- `--metrics`: Evaluation metrics, which depends on the dataset, e.g., "accuracy", "precision", "recall", "f1_score", "support" for single label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for multi-label dataset. +- `--show`: Determines whether to show classifier outputs. If not specified, it will be set to `False`. +- `--show-dir`: Directory where painted images will be saved +- `--metrics-options`: Custom options for evaluation, the key-value pair in `xxx=yyy` format will be kwargs for `dataset.evaluate()` function +- `--cfg-options`: Override some settings in the used config file, the key-value pair in `xxx=yyy` format will be merged into config file. + +### Results and Models + +This part selects ImageNet for onnxruntime verification. ImageNet has multiple versions, but the most commonly used one is [ILSVRC 2012](http://www.image-net.org/challenges/LSVRC/2012/). + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
ModelConfigMetricPyTorchONNX Runtime
ResNetresnet50_b32x8_imagenet.pyTop 1 / 576.55 / 93.1576.49 / 93.22
ResNeXtresnext50_32x4d_b32x8_imagenet.pyTop 1 / 577.83 / 93.6577.83 / 93.65
SE-ResNetseresnet50_b32x8_imagenet.pyTop 1 / 577.74 / 93.8477.74 / 93.84
ShuffleNetV1shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.pyTop 1 / 568.13 / 87.8168.13 / 87.81
ShuffleNetV2shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.pyTop 1 / 569.55 / 88.9269.55 / 88.92
MobileNetV2mobilenet_v2_b32x8_imagenet.pyTop 1 / 571.86 / 90.4271.86 / 90.42
+ ## List of supported models exportable to ONNX The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime. | Model | Config | Batch Inference | Dynamic Shape | Note | | :----------: | :----------------------------------------------------------: | :-------------: | :-----------: | ---- | -| MobileNetV2 | `configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py` | Y | Y | | -| ResNet | `configs/resnet/resnet18_b16x8_cifar10.py` | Y | Y | | -| ResNeXt | `configs/resnext/resnext50_32x4d_b32x8_imagenet.py` | Y | Y | | -| SE-ResNet | `configs/seresnet/seresnet50_b32x8_imagenet.py` | Y | Y | | -| ShuffleNetV1 | `configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y | Y | | -| ShuffleNetV2 | `configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y | Y | | +| MobileNetV2 | [mobilenet_v2_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py) | Y | Y | | +| ResNet | [resnet18_b16x8_cifar10.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet/resnet18_b16x8_cifar10.py) | Y | Y | | +| ResNeXt | [resnext50_32x4d_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext/resnext50_32x4d_b32x8_imagenet.py) | Y | Y | | +| SE-ResNet | [seresnet50_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet/seresnet50_b32x8_imagenet.py) | Y | Y | | +| ShuffleNetV1 | [shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | | +| ShuffleNetV2 | [shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | | Notes: diff --git a/mmcls/core/export/__init__.py b/mmcls/core/export/__init__.py new file mode 100644 index 000000000..955e72632 --- /dev/null +++ b/mmcls/core/export/__init__.py @@ -0,0 +1,3 @@ +from .test import ONNXRuntimeClassifier + +__all__ = ['ONNXRuntimeClassifier'] diff --git a/mmcls/core/export/test.py b/mmcls/core/export/test.py new file mode 100644 index 000000000..11f44abf8 --- /dev/null +++ b/mmcls/core/export/test.py @@ -0,0 +1,57 @@ +import numpy as np +import onnxruntime as ort + +from mmcls.models.classifiers import BaseClassifier + + +class ONNXRuntimeClassifier(BaseClassifier): + """Wrapper for classifier's inference with ONNXRuntime.""" + + def __init__(self, onnx_file, class_names, device_id): + super(ONNXRuntimeClassifier, self).__init__() + sess = ort.InferenceSession(onnx_file) + + providers = ['CPUExecutionProvider'] + options = [{}] + is_cuda_available = ort.get_device() == 'GPU' + if is_cuda_available: + providers.insert(0, 'CUDAExecutionProvider') + options.insert(0, {'device_id': device_id}) + sess.set_providers(providers, options) + + self.sess = sess + self.CLASSES = class_names + self.device_id = device_id + self.io_binding = sess.io_binding() + self.output_names = [_.name for _ in sess.get_outputs()] + self.is_cuda_available = is_cuda_available + + def simple_test(self, img, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def extract_feat(self, imgs): + raise NotImplementedError('This method is not implemented.') + + def forward_train(self, imgs, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def forward_test(self, imgs, img_metas, **kwargs): + input_data = imgs + # set io binding for inputs/outputs + device_type = 'cuda' if self.is_cuda_available else 'cpu' + if not self.is_cuda_available: + input_data = input_data.cpu() + self.io_binding.bind_input( + name='input', + device_type=device_type, + device_id=self.device_id, + element_type=np.float32, + shape=input_data.shape, + buffer_ptr=input_data.data_ptr()) + + for name in self.output_names: + self.io_binding.bind_output(name) + # run session to get outputs + self.sess.run_with_iobinding(self.io_binding) + results = self.io_binding.copy_outputs_to_cpu()[0] + return list(results) diff --git a/tools/onnx2tensorrt.py b/tools/deployment/onnx2tensorrt.py similarity index 100% rename from tools/onnx2tensorrt.py rename to tools/deployment/onnx2tensorrt.py diff --git a/tools/pytorch2onnx.py b/tools/deployment/pytorch2onnx.py similarity index 97% rename from tools/pytorch2onnx.py rename to tools/deployment/pytorch2onnx.py index 3303066cf..98e7078c1 100644 --- a/tools/pytorch2onnx.py +++ b/tools/deployment/pytorch2onnx.py @@ -37,7 +37,7 @@ def _demo_mm_inputs(input_shape, num_classes): def pytorch2onnx(model, input_shape, opset_version=11, - dynamic_shape=False, + dynamic_export=False, show=False, output_file='tmp.onnx', do_simplify=False, @@ -70,7 +70,7 @@ def pytorch2onnx(model, register_extra_symbolics(opset_version) # support dynamic shape export - if dynamic_shape: + if dynamic_export: dynamic_axes = { 'input': { 0: 'batch', @@ -121,7 +121,7 @@ def pytorch2onnx(model, output_file, input_shapes=input_shape_dic, input_data=input_dic, - dynamic_input_shape=dynamic_shape) + dynamic_input_shape=dynamic_export) if verify: # check by onnx import onnx @@ -129,7 +129,7 @@ def pytorch2onnx(model, onnx.checker.check_model(onnx_model) # test the dynamic model - if dynamic_shape: + if dynamic_export: dynamic_test_inputs = _demo_mm_inputs( (input_shape[0], input_shape[1], input_shape[2] * 2, input_shape[3] * 2), model.head.num_classes) @@ -176,7 +176,7 @@ def parse_args(): default=[224, 224], help='input image size') parser.add_argument( - '--dynamic-shape', + '--dynamic-export', action='store_true', help='Whether to export ONNX with dynamic input shape. \ Defaults to False.') @@ -212,7 +212,7 @@ if __name__ == '__main__': input_shape, opset_version=args.opset_version, show=args.show, - dynamic_shape=args.dynamic_shape, + dynamic_export=args.dynamic_export, output_file=args.output_file, do_simplify=args.simplify, verify=args.verify) diff --git a/tools/deployment/test.py b/tools/deployment/test.py new file mode 100644 index 000000000..6d0e5ad05 --- /dev/null +++ b/tools/deployment/test.py @@ -0,0 +1,103 @@ +import argparse +import warnings + +import mmcv +import numpy as np +from mmcv import DictAction +from mmcv.parallel import MMDataParallel + +from mmcls.apis import single_gpu_test +from mmcls.core.export import ONNXRuntimeClassifier +from mmcls.datasets import build_dataloader, build_dataset + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Test (and eval) an ONNX model using ONNXRuntime.') + parser.add_argument('config', help='model config file') + parser.add_argument('model', help='filename of the input ONNX model') + parser.add_argument( + '--out', type=str, help='output result file in pickle format') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file.') + parser.add_argument( + '--metrics', + type=str, + nargs='+', + help='evaluation metrics, which depends on the dataset, e.g., ' + '"accuracy", "precision", "recall", "f1_score", "support" for single ' + 'label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for ' + 'multi-label dataset') + parser.add_argument( + '--metric-options', + nargs='+', + action=DictAction, + default={}, + help='custom options for evaluation, the key-value pair in xxx=yyy ' + 'format will be parsed as a dict metric_options for dataset.evaluate()' + ' function.') + parser.add_argument('--show', action='store_true', help='show results') + parser.add_argument( + '--show-dir', help='directory where painted images will be saved') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): + raise ValueError('The output file must be a pkl file.') + + cfg = mmcv.Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # build dataset and dataloader + dataset = build_dataset(cfg.data.test) + data_loader = build_dataloader( + dataset, + samples_per_gpu=cfg.data.samples_per_gpu, + workers_per_gpu=cfg.data.workers_per_gpu, + shuffle=False, + round_up=False) + + # build onnxruntime model and run inference. + model = ONNXRuntimeClassifier( + args.model, class_names=dataset.CLASSES, device_id=0) + + model = MMDataParallel(model, device_ids=[0]) + outputs = single_gpu_test(model, data_loader, args.show, args.show_dir) + + if args.metrics: + results = dataset.evaluate(outputs, args.metrics, args.metric_options) + for k, v in results.items(): + print(f'\n{k} : {v:.2f}') + else: + warnings.warn('Evaluation metrics are not specified.') + scores = np.vstack(outputs) + pred_score = np.max(scores, axis=1) + pred_label = np.argmax(scores, axis=1) + pred_class = [dataset.CLASSES[lb] for lb in pred_label] + results = { + 'pred_score': pred_score, + 'pred_label': pred_label, + 'pred_class': pred_class + } + if not args.out: + print('\nthe predicted result for the first element is ' + f'pred_score = {pred_score[0]:.2f}, ' + f'pred_label = {pred_label[0]} ' + f'and pred_class = {pred_class[0]}. ' + 'Specify --out to save all results to files.') + if args.out: + print(f'\nwriting results to {args.out}') + mmcv.dump(results, args.out) + + +if __name__ == '__main__': + main()