diff --git a/docs/tutorials/onnx2tensorrt.md b/docs/tutorials/onnx2tensorrt.md index b5cbd99c..b2d5aa4d 100644 --- a/docs/tutorials/onnx2tensorrt.md +++ b/docs/tutorials/onnx2tensorrt.md @@ -26,7 +26,9 @@ python tools/deployment/onnx2tensorrt.py \ ${MODEL} \ --trt-file ${TRT_FILE} \ --shape ${IMAGE_SHAPE} \ - --workspace-size {WORKSPACE_SIZE} \ + --max-batch-size ${MAX_BATCH_SIZE} \ + --workspace-size ${WORKSPACE_SIZE} \ + --fp16 \ --show \ --verify \ ``` @@ -36,6 +38,8 @@ Description of all arguments: - `model` : The path of an ONNX model file. - `--trt-file`: The Path of output TensorRT engine file. If not specified, it will be set to `tmp.trt`. - `--shape`: The height and width of model input. If not specified, it will be set to `224 224`. +- `--max-batch-size`: The max batch size of TensorRT model, should not be less than 1. +- `--fp16`: Enable fp16 mode. - `--workspace-size` : The required GPU workspace size in GiB to build TensorRT engine. If not specified, it will be set to `1` GiB. - `--show`: Determines whether to show the outputs of the model. If not specified, it will be set to `False`. - `--verify`: Determines whether to verify the correctness of models between ONNXRuntime and TensorRT. If not specified, it will be set to `False`. @@ -55,11 +59,11 @@ python tools/deployment/onnx2tensorrt.py \ The table below lists the models that are guaranteed to be convertable to TensorRT. -| Model | Config | Status | -| :----------: | :----------------------------------------------------------: | :----: | -| MobileNetV2 | `configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py` | Y | -| ResNet | `configs/resnet/resnet18_b16x8_cifar10.py` | Y | -| ResNeXt | `configs/resnext/resnext50_32x4d_b32x8_imagenet.py` | Y | +| Model | Config | Status | +| :----------: | :--------------------------------------------------------------------------: | :----: | +| MobileNetV2 | `configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py` | Y | +| ResNet | `configs/resnet/resnet18_b16x8_cifar10.py` | Y | +| ResNeXt | `configs/resnext/resnext50_32x4d_b32x8_imagenet.py` | Y | | ShuffleNetV1 | `configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y | | ShuffleNetV2 | `configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y | diff --git a/docs/tutorials/pytorch2onnx.md b/docs/tutorials/pytorch2onnx.md index b4f5d880..0032a5d8 100644 --- a/docs/tutorials/pytorch2onnx.md +++ b/docs/tutorials/pytorch2onnx.md @@ -6,7 +6,7 @@ - [How to convert models from Pytorch to ONNX](#how-to-convert-models-from-pytorch-to-onnx) - [Prerequisite](#prerequisite) - [Usage](#usage) - - [Description of all arguments](#description-of-all-arguments) + - [Description of all arguments:](#description-of-all-arguments) - [How to evaluate ONNX models with ONNX Runtime](#how-to-evaluate-onnx-models-with-onnx-runtime) - [Prerequisite](#prerequisite-1) - [Usage](#usage-1) @@ -71,7 +71,7 @@ python tools/deployment/pytorch2onnx.py \ ## How to evaluate ONNX models with ONNX Runtime -We prepare a tool `tools/deployment/test.py` to evaluate ONNX models with ONNX Runtime backend. +We prepare a tool `tools/deployment/test.py` to evaluate ONNX models with ONNXRuntime or TensorRT. ### Prerequisite @@ -87,6 +87,7 @@ We prepare a tool `tools/deployment/test.py` to evaluate ONNX models with ONNX R python tools/deployment/test.py \ ${CONFIG_FILE} \ ${ONNX_FILE} \ + --backend ${BACKEND} \ --out ${OUTPUT_FILE} \ --metrics ${EVALUATION_METRICS} \ --metric-options ${EVALUATION_OPTIONS} \ @@ -99,6 +100,7 @@ python tools/deployment/test.py \ - `config`: The path of a model config file. - `model`: The path of a ONNX model file. +- `--backend`: Backend for input model to run and should be `onnxruntime` or `tensorrt`. - `--out`: The path of output result file in pickle format. - `--metrics`: Evaluation metrics, which depends on the dataset, e.g., "accuracy", "precision", "recall", "f1_score", "support" for single label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for multi-label dataset. - `--show`: Determines whether to show classifier outputs. If not specified, it will be set to `False`. @@ -117,6 +119,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v Metric PyTorch ONNXRuntime + TensorRT-fp32 + TensorRT-fp16 ResNet @@ -124,13 +128,17 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v Top 1 / 5 76.55 / 93.15 76.49 / 93.22 + 76.49 / 93.22 + 76.50 / 93.20 ResNeXt resnext50_32x4d_b32x8_imagenet.py Top 1 / 5 - 77.83 / 93.65 - 77.83 / 93.65 + 77.90 / 93.66 + 77.90 / 93.66 + 77.90 / 93.66 + 77.89 / 93.65 SE-ResNet @@ -138,6 +146,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v Top 1 / 5 77.74 / 93.84 77.74 / 93.84 + 77.74 / 93.84 + 77.74 / 93.85 ShuffleNetV1 @@ -145,6 +155,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v Top 1 / 5 68.13 / 87.81 68.13 / 87.81 + 68.13 / 87.81 + 68.10 / 87.80 ShuffleNetV2 @@ -152,6 +164,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v Top 1 / 5 69.55 / 88.92 69.55 / 88.92 + 69.55 / 88.92 + 69.55 / 88.92 MobileNetV2 @@ -159,6 +173,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v Top 1 / 5 71.86 / 90.42 71.86 / 90.42 + 71.86 / 90.42 + 71.88 / 90.40 @@ -166,12 +182,12 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime. -| Model | Config | Batch Inference | Dynamic Shape | Note | -| :----------: | :----------------------------------------------------------: | :-------------: | :-----------: | ---- | -| MobileNetV2 | [mobilenet_v2_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py) | Y | Y | | -| ResNet | [resnet18_b16x8_cifar10.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet/resnet18_b16x8_cifar10.py) | Y | Y | | -| ResNeXt | [resnext50_32x4d_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext/resnext50_32x4d_b32x8_imagenet.py) | Y | Y | | -| SE-ResNet | [seresnet50_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet/seresnet50_b32x8_imagenet.py) | Y | Y | | +| Model | Config | Batch Inference | Dynamic Shape | Note | +| :----------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------: | :-----------: | ---- | +| MobileNetV2 | [mobilenet_v2_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py) | Y | Y | | +| ResNet | [resnet18_b16x8_cifar10.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet/resnet18_b16x8_cifar10.py) | Y | Y | | +| ResNeXt | [resnext50_32x4d_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext/resnext50_32x4d_b32x8_imagenet.py) | Y | Y | | +| SE-ResNet | [seresnet50_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet/seresnet50_b32x8_imagenet.py) | Y | Y | | | ShuffleNetV1 | [shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | | | ShuffleNetV2 | [shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | | diff --git a/mmcls/core/export/__init__.py b/mmcls/core/export/__init__.py index 955e7263..4a4e2bf3 100644 --- a/mmcls/core/export/__init__.py +++ b/mmcls/core/export/__init__.py @@ -1,3 +1,3 @@ -from .test import ONNXRuntimeClassifier +from .test import ONNXRuntimeClassifier, TensorRTClassifier -__all__ = ['ONNXRuntimeClassifier'] +__all__ = ['ONNXRuntimeClassifier', 'TensorRTClassifier'] diff --git a/mmcls/core/export/test.py b/mmcls/core/export/test.py index 11f44abf..52245c55 100644 --- a/mmcls/core/export/test.py +++ b/mmcls/core/export/test.py @@ -1,5 +1,8 @@ +import warnings + import numpy as np import onnxruntime as ort +import torch from mmcls.models.classifiers import BaseClassifier @@ -55,3 +58,38 @@ class ONNXRuntimeClassifier(BaseClassifier): self.sess.run_with_iobinding(self.io_binding) results = self.io_binding.copy_outputs_to_cpu()[0] return list(results) + + +class TensorRTClassifier(BaseClassifier): + + def __init__(self, trt_file, class_names, device_id): + super(TensorRTClassifier, self).__init__() + from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin + try: + load_tensorrt_plugin() + except (ImportError, ModuleNotFoundError): + warnings.warn('If input model has custom op from mmcv, \ + you may have to build mmcv with TensorRT from source.') + model = TRTWraper( + trt_file, input_names=['input'], output_names=['probs']) + + self.model = model + self.device_id = device_id + self.CLASSES = class_names + + def simple_test(self, img, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def extract_feat(self, imgs): + raise NotImplementedError('This method is not implemented.') + + def forward_train(self, imgs, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def forward_test(self, imgs, img_metas, **kwargs): + input_data = imgs + with torch.cuda.device(self.device_id), torch.no_grad(): + results = self.model({'input': input_data})['probs'] + results = results.detach().cpu().numpy() + + return list(results) diff --git a/tools/deployment/onnx2tensorrt.py b/tools/deployment/onnx2tensorrt.py index d6271e5d..60454066 100644 --- a/tools/deployment/onnx2tensorrt.py +++ b/tools/deployment/onnx2tensorrt.py @@ -13,6 +13,8 @@ def get_GiB(x: int): def onnx2tensorrt(onnx_file, trt_file, input_shape, + max_batch_size, + fp16_mode=False, verify=False, workspace_size=1): """Create tensorrt engine from onnx model. @@ -22,6 +24,7 @@ def onnx2tensorrt(onnx_file, trt_file (str): Filename of the output TensorRT engine file. input_shape (list[int]): Input shape of the model. eg [1, 3, 224, 224]. + max_batch_size (int): Max batch size of the model. verify (bool, optional): Whether to verify the converted model. Defaults to False. workspace_size (int, optional): Maximium workspace of GPU. @@ -32,12 +35,14 @@ def onnx2tensorrt(onnx_file, onnx_model = onnx.load(onnx_file) # create trt engine and wraper - opt_shape_dict = {'input': [input_shape, input_shape, input_shape]} + assert max_batch_size >= 1 + max_shape = [max_batch_size] + list(input_shape[1:]) + opt_shape_dict = {'input': [input_shape, input_shape, max_shape]} max_workspace_size = get_GiB(workspace_size) trt_engine = onnx2trt( onnx_model, opt_shape_dict, - fp16_mode=False, + fp16_mode=fp16_mode, max_workspace_size=max_workspace_size) save_dir, _ = osp.split(trt_file) if save_dir: @@ -99,6 +104,12 @@ def parse_args(): nargs='+', default=[224, 224], help='Input size of the model') + parser.add_argument( + '--max-batch-size', + type=int, + default=1, + help='Maximum batch size of TensorRT model.') + parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode') parser.add_argument( '--workspace-size', type=int, @@ -124,5 +135,7 @@ if __name__ == '__main__': args.model, args.trt_file, input_shape, + args.max_batch_size, + fp16_mode=args.fp16, verify=args.verify, workspace_size=args.workspace_size) diff --git a/tools/deployment/test.py b/tools/deployment/test.py index 6d0e5ad0..b57665c6 100644 --- a/tools/deployment/test.py +++ b/tools/deployment/test.py @@ -7,7 +7,7 @@ from mmcv import DictAction from mmcv.parallel import MMDataParallel from mmcls.apis import single_gpu_test -from mmcls.core.export import ONNXRuntimeClassifier +from mmcls.core.export import ONNXRuntimeClassifier, TensorRTClassifier from mmcls.datasets import build_dataloader, build_dataset @@ -16,6 +16,10 @@ def parse_args(): description='Test (and eval) an ONNX model using ONNXRuntime.') parser.add_argument('config', help='model config file') parser.add_argument('model', help='filename of the input ONNX model') + parser.add_argument( + '--backend', + help='Backend of the model.', + choices=['onnxruntime', 'tensorrt']) parser.add_argument( '--out', type=str, help='output result file in pickle format') parser.add_argument( @@ -67,10 +71,18 @@ def main(): round_up=False) # build onnxruntime model and run inference. - model = ONNXRuntimeClassifier( - args.model, class_names=dataset.CLASSES, device_id=0) + if args.backend == 'onnxruntime': + model = ONNXRuntimeClassifier( + args.model, class_names=dataset.CLASSES, device_id=0) + elif args.backend == 'tensorrt': + model = TensorRTClassifier( + args.model, class_names=dataset.CLASSES, device_id=0) + else: + print('Unknown backend: {}.'.format(args.model)) + exit() model = MMDataParallel(model, device_ids=[0]) + model.CLASSES = dataset.CLASSES outputs = single_gpu_test(model, data_loader, args.show, args.show_dir) if args.metrics: