diff --git a/docs/tutorials/onnx2tensorrt.md b/docs/tutorials/onnx2tensorrt.md
index b5cbd99c..b2d5aa4d 100644
--- a/docs/tutorials/onnx2tensorrt.md
+++ b/docs/tutorials/onnx2tensorrt.md
@@ -26,7 +26,9 @@ python tools/deployment/onnx2tensorrt.py \
${MODEL} \
--trt-file ${TRT_FILE} \
--shape ${IMAGE_SHAPE} \
- --workspace-size {WORKSPACE_SIZE} \
+ --max-batch-size ${MAX_BATCH_SIZE} \
+ --workspace-size ${WORKSPACE_SIZE} \
+ --fp16 \
--show \
--verify \
```
@@ -36,6 +38,8 @@ Description of all arguments:
- `model` : The path of an ONNX model file.
- `--trt-file`: The Path of output TensorRT engine file. If not specified, it will be set to `tmp.trt`.
- `--shape`: The height and width of model input. If not specified, it will be set to `224 224`.
+- `--max-batch-size`: The max batch size of TensorRT model, should not be less than 1.
+- `--fp16`: Enable fp16 mode.
- `--workspace-size` : The required GPU workspace size in GiB to build TensorRT engine. If not specified, it will be set to `1` GiB.
- `--show`: Determines whether to show the outputs of the model. If not specified, it will be set to `False`.
- `--verify`: Determines whether to verify the correctness of models between ONNXRuntime and TensorRT. If not specified, it will be set to `False`.
@@ -55,11 +59,11 @@ python tools/deployment/onnx2tensorrt.py \
The table below lists the models that are guaranteed to be convertable to TensorRT.
-| Model | Config | Status |
-| :----------: | :----------------------------------------------------------: | :----: |
-| MobileNetV2 | `configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py` | Y |
-| ResNet | `configs/resnet/resnet18_b16x8_cifar10.py` | Y |
-| ResNeXt | `configs/resnext/resnext50_32x4d_b32x8_imagenet.py` | Y |
+| Model | Config | Status |
+| :----------: | :--------------------------------------------------------------------------: | :----: |
+| MobileNetV2 | `configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py` | Y |
+| ResNet | `configs/resnet/resnet18_b16x8_cifar10.py` | Y |
+| ResNeXt | `configs/resnext/resnext50_32x4d_b32x8_imagenet.py` | Y |
| ShuffleNetV1 | `configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y |
| ShuffleNetV2 | `configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y |
diff --git a/docs/tutorials/pytorch2onnx.md b/docs/tutorials/pytorch2onnx.md
index b4f5d880..0032a5d8 100644
--- a/docs/tutorials/pytorch2onnx.md
+++ b/docs/tutorials/pytorch2onnx.md
@@ -6,7 +6,7 @@
- [How to convert models from Pytorch to ONNX](#how-to-convert-models-from-pytorch-to-onnx)
- [Prerequisite](#prerequisite)
- [Usage](#usage)
- - [Description of all arguments](#description-of-all-arguments)
+ - [Description of all arguments:](#description-of-all-arguments)
- [How to evaluate ONNX models with ONNX Runtime](#how-to-evaluate-onnx-models-with-onnx-runtime)
- [Prerequisite](#prerequisite-1)
- [Usage](#usage-1)
@@ -71,7 +71,7 @@ python tools/deployment/pytorch2onnx.py \
## How to evaluate ONNX models with ONNX Runtime
-We prepare a tool `tools/deployment/test.py` to evaluate ONNX models with ONNX Runtime backend.
+We prepare a tool `tools/deployment/test.py` to evaluate ONNX models with ONNXRuntime or TensorRT.
### Prerequisite
@@ -87,6 +87,7 @@ We prepare a tool `tools/deployment/test.py` to evaluate ONNX models with ONNX R
python tools/deployment/test.py \
${CONFIG_FILE} \
${ONNX_FILE} \
+ --backend ${BACKEND} \
--out ${OUTPUT_FILE} \
--metrics ${EVALUATION_METRICS} \
--metric-options ${EVALUATION_OPTIONS} \
@@ -99,6 +100,7 @@ python tools/deployment/test.py \
- `config`: The path of a model config file.
- `model`: The path of a ONNX model file.
+- `--backend`: Backend for input model to run and should be `onnxruntime` or `tensorrt`.
- `--out`: The path of output result file in pickle format.
- `--metrics`: Evaluation metrics, which depends on the dataset, e.g., "accuracy", "precision", "recall", "f1_score", "support" for single label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for multi-label dataset.
- `--show`: Determines whether to show classifier outputs. If not specified, it will be set to `False`.
@@ -117,6 +119,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
Metric |
PyTorch |
ONNXRuntime |
+ TensorRT-fp32 |
+ TensorRT-fp16 |
ResNet |
@@ -124,13 +128,17 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
Top 1 / 5 |
76.55 / 93.15 |
76.49 / 93.22 |
+ 76.49 / 93.22 |
+ 76.50 / 93.20 |
ResNeXt |
resnext50_32x4d_b32x8_imagenet.py |
Top 1 / 5 |
- 77.83 / 93.65 |
- 77.83 / 93.65 |
+ 77.90 / 93.66 |
+ 77.90 / 93.66 |
+ 77.90 / 93.66 |
+ 77.89 / 93.65 |
SE-ResNet |
@@ -138,6 +146,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
Top 1 / 5 |
77.74 / 93.84 |
77.74 / 93.84 |
+ 77.74 / 93.84 |
+ 77.74 / 93.85 |
ShuffleNetV1 |
@@ -145,6 +155,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
Top 1 / 5 |
68.13 / 87.81 |
68.13 / 87.81 |
+ 68.13 / 87.81 |
+ 68.10 / 87.80 |
ShuffleNetV2 |
@@ -152,6 +164,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
Top 1 / 5 |
69.55 / 88.92 |
69.55 / 88.92 |
+ 69.55 / 88.92 |
+ 69.55 / 88.92 |
MobileNetV2 |
@@ -159,6 +173,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
Top 1 / 5 |
71.86 / 90.42 |
71.86 / 90.42 |
+ 71.86 / 90.42 |
+ 71.88 / 90.40 |
@@ -166,12 +182,12 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime.
-| Model | Config | Batch Inference | Dynamic Shape | Note |
-| :----------: | :----------------------------------------------------------: | :-------------: | :-----------: | ---- |
-| MobileNetV2 | [mobilenet_v2_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py) | Y | Y | |
-| ResNet | [resnet18_b16x8_cifar10.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet/resnet18_b16x8_cifar10.py) | Y | Y | |
-| ResNeXt | [resnext50_32x4d_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext/resnext50_32x4d_b32x8_imagenet.py) | Y | Y | |
-| SE-ResNet | [seresnet50_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet/seresnet50_b32x8_imagenet.py) | Y | Y | |
+| Model | Config | Batch Inference | Dynamic Shape | Note |
+| :----------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------: | :-----------: | ---- |
+| MobileNetV2 | [mobilenet_v2_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py) | Y | Y | |
+| ResNet | [resnet18_b16x8_cifar10.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet/resnet18_b16x8_cifar10.py) | Y | Y | |
+| ResNeXt | [resnext50_32x4d_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext/resnext50_32x4d_b32x8_imagenet.py) | Y | Y | |
+| SE-ResNet | [seresnet50_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet/seresnet50_b32x8_imagenet.py) | Y | Y | |
| ShuffleNetV1 | [shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | |
| ShuffleNetV2 | [shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | |
diff --git a/mmcls/core/export/__init__.py b/mmcls/core/export/__init__.py
index 955e7263..4a4e2bf3 100644
--- a/mmcls/core/export/__init__.py
+++ b/mmcls/core/export/__init__.py
@@ -1,3 +1,3 @@
-from .test import ONNXRuntimeClassifier
+from .test import ONNXRuntimeClassifier, TensorRTClassifier
-__all__ = ['ONNXRuntimeClassifier']
+__all__ = ['ONNXRuntimeClassifier', 'TensorRTClassifier']
diff --git a/mmcls/core/export/test.py b/mmcls/core/export/test.py
index 11f44abf..52245c55 100644
--- a/mmcls/core/export/test.py
+++ b/mmcls/core/export/test.py
@@ -1,5 +1,8 @@
+import warnings
+
import numpy as np
import onnxruntime as ort
+import torch
from mmcls.models.classifiers import BaseClassifier
@@ -55,3 +58,38 @@ class ONNXRuntimeClassifier(BaseClassifier):
self.sess.run_with_iobinding(self.io_binding)
results = self.io_binding.copy_outputs_to_cpu()[0]
return list(results)
+
+
+class TensorRTClassifier(BaseClassifier):
+
+ def __init__(self, trt_file, class_names, device_id):
+ super(TensorRTClassifier, self).__init__()
+ from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
+ try:
+ load_tensorrt_plugin()
+ except (ImportError, ModuleNotFoundError):
+ warnings.warn('If input model has custom op from mmcv, \
+ you may have to build mmcv with TensorRT from source.')
+ model = TRTWraper(
+ trt_file, input_names=['input'], output_names=['probs'])
+
+ self.model = model
+ self.device_id = device_id
+ self.CLASSES = class_names
+
+ def simple_test(self, img, img_metas, **kwargs):
+ raise NotImplementedError('This method is not implemented.')
+
+ def extract_feat(self, imgs):
+ raise NotImplementedError('This method is not implemented.')
+
+ def forward_train(self, imgs, **kwargs):
+ raise NotImplementedError('This method is not implemented.')
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ input_data = imgs
+ with torch.cuda.device(self.device_id), torch.no_grad():
+ results = self.model({'input': input_data})['probs']
+ results = results.detach().cpu().numpy()
+
+ return list(results)
diff --git a/tools/deployment/onnx2tensorrt.py b/tools/deployment/onnx2tensorrt.py
index d6271e5d..60454066 100644
--- a/tools/deployment/onnx2tensorrt.py
+++ b/tools/deployment/onnx2tensorrt.py
@@ -13,6 +13,8 @@ def get_GiB(x: int):
def onnx2tensorrt(onnx_file,
trt_file,
input_shape,
+ max_batch_size,
+ fp16_mode=False,
verify=False,
workspace_size=1):
"""Create tensorrt engine from onnx model.
@@ -22,6 +24,7 @@ def onnx2tensorrt(onnx_file,
trt_file (str): Filename of the output TensorRT engine file.
input_shape (list[int]): Input shape of the model.
eg [1, 3, 224, 224].
+ max_batch_size (int): Max batch size of the model.
verify (bool, optional): Whether to verify the converted model.
Defaults to False.
workspace_size (int, optional): Maximium workspace of GPU.
@@ -32,12 +35,14 @@ def onnx2tensorrt(onnx_file,
onnx_model = onnx.load(onnx_file)
# create trt engine and wraper
- opt_shape_dict = {'input': [input_shape, input_shape, input_shape]}
+ assert max_batch_size >= 1
+ max_shape = [max_batch_size] + list(input_shape[1:])
+ opt_shape_dict = {'input': [input_shape, input_shape, max_shape]}
max_workspace_size = get_GiB(workspace_size)
trt_engine = onnx2trt(
onnx_model,
opt_shape_dict,
- fp16_mode=False,
+ fp16_mode=fp16_mode,
max_workspace_size=max_workspace_size)
save_dir, _ = osp.split(trt_file)
if save_dir:
@@ -99,6 +104,12 @@ def parse_args():
nargs='+',
default=[224, 224],
help='Input size of the model')
+ parser.add_argument(
+ '--max-batch-size',
+ type=int,
+ default=1,
+ help='Maximum batch size of TensorRT model.')
+ parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
parser.add_argument(
'--workspace-size',
type=int,
@@ -124,5 +135,7 @@ if __name__ == '__main__':
args.model,
args.trt_file,
input_shape,
+ args.max_batch_size,
+ fp16_mode=args.fp16,
verify=args.verify,
workspace_size=args.workspace_size)
diff --git a/tools/deployment/test.py b/tools/deployment/test.py
index 6d0e5ad0..b57665c6 100644
--- a/tools/deployment/test.py
+++ b/tools/deployment/test.py
@@ -7,7 +7,7 @@ from mmcv import DictAction
from mmcv.parallel import MMDataParallel
from mmcls.apis import single_gpu_test
-from mmcls.core.export import ONNXRuntimeClassifier
+from mmcls.core.export import ONNXRuntimeClassifier, TensorRTClassifier
from mmcls.datasets import build_dataloader, build_dataset
@@ -16,6 +16,10 @@ def parse_args():
description='Test (and eval) an ONNX model using ONNXRuntime.')
parser.add_argument('config', help='model config file')
parser.add_argument('model', help='filename of the input ONNX model')
+ parser.add_argument(
+ '--backend',
+ help='Backend of the model.',
+ choices=['onnxruntime', 'tensorrt'])
parser.add_argument(
'--out', type=str, help='output result file in pickle format')
parser.add_argument(
@@ -67,10 +71,18 @@ def main():
round_up=False)
# build onnxruntime model and run inference.
- model = ONNXRuntimeClassifier(
- args.model, class_names=dataset.CLASSES, device_id=0)
+ if args.backend == 'onnxruntime':
+ model = ONNXRuntimeClassifier(
+ args.model, class_names=dataset.CLASSES, device_id=0)
+ elif args.backend == 'tensorrt':
+ model = TensorRTClassifier(
+ args.model, class_names=dataset.CLASSES, device_id=0)
+ else:
+ print('Unknown backend: {}.'.format(args.model))
+ exit()
model = MMDataParallel(model, device_ids=[0])
+ model.CLASSES = dataset.CLASSES
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
if args.metrics: