mirror of
https://github.com/open-mmlab/mmpretrain.git
synced 2025-06-03 14:59:18 +08:00
[Feature] TensorRT test tools. (#284)
* first commit * update resnext result * update docs * update docstring
This commit is contained in:
parent
e2507cae97
commit
dbddde52ef
@ -26,7 +26,9 @@ python tools/deployment/onnx2tensorrt.py \
|
|||||||
${MODEL} \
|
${MODEL} \
|
||||||
--trt-file ${TRT_FILE} \
|
--trt-file ${TRT_FILE} \
|
||||||
--shape ${IMAGE_SHAPE} \
|
--shape ${IMAGE_SHAPE} \
|
||||||
--workspace-size {WORKSPACE_SIZE} \
|
--max-batch-size ${MAX_BATCH_SIZE} \
|
||||||
|
--workspace-size ${WORKSPACE_SIZE} \
|
||||||
|
--fp16 \
|
||||||
--show \
|
--show \
|
||||||
--verify \
|
--verify \
|
||||||
```
|
```
|
||||||
@ -36,6 +38,8 @@ Description of all arguments:
|
|||||||
- `model` : The path of an ONNX model file.
|
- `model` : The path of an ONNX model file.
|
||||||
- `--trt-file`: The Path of output TensorRT engine file. If not specified, it will be set to `tmp.trt`.
|
- `--trt-file`: The Path of output TensorRT engine file. If not specified, it will be set to `tmp.trt`.
|
||||||
- `--shape`: The height and width of model input. If not specified, it will be set to `224 224`.
|
- `--shape`: The height and width of model input. If not specified, it will be set to `224 224`.
|
||||||
|
- `--max-batch-size`: The max batch size of TensorRT model, should not be less than 1.
|
||||||
|
- `--fp16`: Enable fp16 mode.
|
||||||
- `--workspace-size` : The required GPU workspace size in GiB to build TensorRT engine. If not specified, it will be set to `1` GiB.
|
- `--workspace-size` : The required GPU workspace size in GiB to build TensorRT engine. If not specified, it will be set to `1` GiB.
|
||||||
- `--show`: Determines whether to show the outputs of the model. If not specified, it will be set to `False`.
|
- `--show`: Determines whether to show the outputs of the model. If not specified, it will be set to `False`.
|
||||||
- `--verify`: Determines whether to verify the correctness of models between ONNXRuntime and TensorRT. If not specified, it will be set to `False`.
|
- `--verify`: Determines whether to verify the correctness of models between ONNXRuntime and TensorRT. If not specified, it will be set to `False`.
|
||||||
@ -55,11 +59,11 @@ python tools/deployment/onnx2tensorrt.py \
|
|||||||
|
|
||||||
The table below lists the models that are guaranteed to be convertable to TensorRT.
|
The table below lists the models that are guaranteed to be convertable to TensorRT.
|
||||||
|
|
||||||
| Model | Config | Status |
|
| Model | Config | Status |
|
||||||
| :----------: | :----------------------------------------------------------: | :----: |
|
| :----------: | :--------------------------------------------------------------------------: | :----: |
|
||||||
| MobileNetV2 | `configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py` | Y |
|
| MobileNetV2 | `configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py` | Y |
|
||||||
| ResNet | `configs/resnet/resnet18_b16x8_cifar10.py` | Y |
|
| ResNet | `configs/resnet/resnet18_b16x8_cifar10.py` | Y |
|
||||||
| ResNeXt | `configs/resnext/resnext50_32x4d_b32x8_imagenet.py` | Y |
|
| ResNeXt | `configs/resnext/resnext50_32x4d_b32x8_imagenet.py` | Y |
|
||||||
| ShuffleNetV1 | `configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y |
|
| ShuffleNetV1 | `configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y |
|
||||||
| ShuffleNetV2 | `configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y |
|
| ShuffleNetV2 | `configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py` | Y |
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
- [How to convert models from Pytorch to ONNX](#how-to-convert-models-from-pytorch-to-onnx)
|
- [How to convert models from Pytorch to ONNX](#how-to-convert-models-from-pytorch-to-onnx)
|
||||||
- [Prerequisite](#prerequisite)
|
- [Prerequisite](#prerequisite)
|
||||||
- [Usage](#usage)
|
- [Usage](#usage)
|
||||||
- [Description of all arguments](#description-of-all-arguments)
|
- [Description of all arguments:](#description-of-all-arguments)
|
||||||
- [How to evaluate ONNX models with ONNX Runtime](#how-to-evaluate-onnx-models-with-onnx-runtime)
|
- [How to evaluate ONNX models with ONNX Runtime](#how-to-evaluate-onnx-models-with-onnx-runtime)
|
||||||
- [Prerequisite](#prerequisite-1)
|
- [Prerequisite](#prerequisite-1)
|
||||||
- [Usage](#usage-1)
|
- [Usage](#usage-1)
|
||||||
@ -71,7 +71,7 @@ python tools/deployment/pytorch2onnx.py \
|
|||||||
|
|
||||||
## How to evaluate ONNX models with ONNX Runtime
|
## How to evaluate ONNX models with ONNX Runtime
|
||||||
|
|
||||||
We prepare a tool `tools/deployment/test.py` to evaluate ONNX models with ONNX Runtime backend.
|
We prepare a tool `tools/deployment/test.py` to evaluate ONNX models with ONNXRuntime or TensorRT.
|
||||||
|
|
||||||
### Prerequisite
|
### Prerequisite
|
||||||
|
|
||||||
@ -87,6 +87,7 @@ We prepare a tool `tools/deployment/test.py` to evaluate ONNX models with ONNX R
|
|||||||
python tools/deployment/test.py \
|
python tools/deployment/test.py \
|
||||||
${CONFIG_FILE} \
|
${CONFIG_FILE} \
|
||||||
${ONNX_FILE} \
|
${ONNX_FILE} \
|
||||||
|
--backend ${BACKEND} \
|
||||||
--out ${OUTPUT_FILE} \
|
--out ${OUTPUT_FILE} \
|
||||||
--metrics ${EVALUATION_METRICS} \
|
--metrics ${EVALUATION_METRICS} \
|
||||||
--metric-options ${EVALUATION_OPTIONS} \
|
--metric-options ${EVALUATION_OPTIONS} \
|
||||||
@ -99,6 +100,7 @@ python tools/deployment/test.py \
|
|||||||
|
|
||||||
- `config`: The path of a model config file.
|
- `config`: The path of a model config file.
|
||||||
- `model`: The path of a ONNX model file.
|
- `model`: The path of a ONNX model file.
|
||||||
|
- `--backend`: Backend for input model to run and should be `onnxruntime` or `tensorrt`.
|
||||||
- `--out`: The path of output result file in pickle format.
|
- `--out`: The path of output result file in pickle format.
|
||||||
- `--metrics`: Evaluation metrics, which depends on the dataset, e.g., "accuracy", "precision", "recall", "f1_score", "support" for single label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for multi-label dataset.
|
- `--metrics`: Evaluation metrics, which depends on the dataset, e.g., "accuracy", "precision", "recall", "f1_score", "support" for single label dataset, and "mAP", "CP", "CR", "CF1", "OP", "OR", "OF1" for multi-label dataset.
|
||||||
- `--show`: Determines whether to show classifier outputs. If not specified, it will be set to `False`.
|
- `--show`: Determines whether to show classifier outputs. If not specified, it will be set to `False`.
|
||||||
@ -117,6 +119,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
|
|||||||
<th align="center">Metric</th>
|
<th align="center">Metric</th>
|
||||||
<th align="center">PyTorch</th>
|
<th align="center">PyTorch</th>
|
||||||
<th align="center">ONNXRuntime</th>
|
<th align="center">ONNXRuntime</th>
|
||||||
|
<th align="center">TensorRT-fp32</th>
|
||||||
|
<th align="center">TensorRT-fp16</th>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td align="center">ResNet</td>
|
<td align="center">ResNet</td>
|
||||||
@ -124,13 +128,17 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
|
|||||||
<td align="center">Top 1 / 5</td>
|
<td align="center">Top 1 / 5</td>
|
||||||
<td align="center">76.55 / 93.15</td>
|
<td align="center">76.55 / 93.15</td>
|
||||||
<td align="center">76.49 / 93.22</td>
|
<td align="center">76.49 / 93.22</td>
|
||||||
|
<td align="center">76.49 / 93.22</td>
|
||||||
|
<td align="center">76.50 / 93.20</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td align="center">ResNeXt</td>
|
<td align="center">ResNeXt</td>
|
||||||
<td align="center"><code>resnext50_32x4d_b32x8_imagenet.py</code></td>
|
<td align="center"><code>resnext50_32x4d_b32x8_imagenet.py</code></td>
|
||||||
<td align="center">Top 1 / 5</td>
|
<td align="center">Top 1 / 5</td>
|
||||||
<td align="center">77.83 / 93.65</td>
|
<td align="center">77.90 / 93.66</td>
|
||||||
<td align="center">77.83 / 93.65</td>
|
<td align="center">77.90 / 93.66</td>
|
||||||
|
<td align="center">77.90 / 93.66</td>
|
||||||
|
<td align="center">77.89 / 93.65</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td align="center">SE-ResNet</td>
|
<td align="center">SE-ResNet</td>
|
||||||
@ -138,6 +146,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
|
|||||||
<td align="center">Top 1 / 5</td>
|
<td align="center">Top 1 / 5</td>
|
||||||
<td align="center">77.74 / 93.84</td>
|
<td align="center">77.74 / 93.84</td>
|
||||||
<td align="center">77.74 / 93.84</td>
|
<td align="center">77.74 / 93.84</td>
|
||||||
|
<td align="center">77.74 / 93.84</td>
|
||||||
|
<td align="center">77.74 / 93.85</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td align="center">ShuffleNetV1</td>
|
<td align="center">ShuffleNetV1</td>
|
||||||
@ -145,6 +155,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
|
|||||||
<td align="center">Top 1 / 5</td>
|
<td align="center">Top 1 / 5</td>
|
||||||
<td align="center">68.13 / 87.81</td>
|
<td align="center">68.13 / 87.81</td>
|
||||||
<td align="center">68.13 / 87.81</td>
|
<td align="center">68.13 / 87.81</td>
|
||||||
|
<td align="center">68.13 / 87.81</td>
|
||||||
|
<td align="center">68.10 / 87.80</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td align="center">ShuffleNetV2</td>
|
<td align="center">ShuffleNetV2</td>
|
||||||
@ -152,6 +164,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
|
|||||||
<td align="center">Top 1 / 5</td>
|
<td align="center">Top 1 / 5</td>
|
||||||
<td align="center">69.55 / 88.92</td>
|
<td align="center">69.55 / 88.92</td>
|
||||||
<td align="center">69.55 / 88.92</td>
|
<td align="center">69.55 / 88.92</td>
|
||||||
|
<td align="center">69.55 / 88.92</td>
|
||||||
|
<td align="center">69.55 / 88.92</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td align="center">MobileNetV2</td>
|
<td align="center">MobileNetV2</td>
|
||||||
@ -159,6 +173,8 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
|
|||||||
<td align="center">Top 1 / 5</td>
|
<td align="center">Top 1 / 5</td>
|
||||||
<td align="center">71.86 / 90.42</td>
|
<td align="center">71.86 / 90.42</td>
|
||||||
<td align="center">71.86 / 90.42</td>
|
<td align="center">71.86 / 90.42</td>
|
||||||
|
<td align="center">71.86 / 90.42</td>
|
||||||
|
<td align="center">71.88 / 90.40</td>
|
||||||
</tr>
|
</tr>
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
@ -166,12 +182,12 @@ This part selects ImageNet for onnxruntime verification. ImageNet has multiple v
|
|||||||
|
|
||||||
The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime.
|
The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime.
|
||||||
|
|
||||||
| Model | Config | Batch Inference | Dynamic Shape | Note |
|
| Model | Config | Batch Inference | Dynamic Shape | Note |
|
||||||
| :----------: | :----------------------------------------------------------: | :-------------: | :-----------: | ---- |
|
| :----------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------: | :-----------: | ---- |
|
||||||
| MobileNetV2 | [mobilenet_v2_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py) | Y | Y | |
|
| MobileNetV2 | [mobilenet_v2_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/mobilenet_v2/mobilenet_v2_b32x8_imagenet.py) | Y | Y | |
|
||||||
| ResNet | [resnet18_b16x8_cifar10.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet/resnet18_b16x8_cifar10.py) | Y | Y | |
|
| ResNet | [resnet18_b16x8_cifar10.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnet/resnet18_b16x8_cifar10.py) | Y | Y | |
|
||||||
| ResNeXt | [resnext50_32x4d_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext/resnext50_32x4d_b32x8_imagenet.py) | Y | Y | |
|
| ResNeXt | [resnext50_32x4d_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/resnext/resnext50_32x4d_b32x8_imagenet.py) | Y | Y | |
|
||||||
| SE-ResNet | [seresnet50_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet/seresnet50_b32x8_imagenet.py) | Y | Y | |
|
| SE-ResNet | [seresnet50_b32x8_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/seresnet/seresnet50_b32x8_imagenet.py) | Y | Y | |
|
||||||
| ShuffleNetV1 | [shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | |
|
| ShuffleNetV1 | [shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v1/shufflenet_v1_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | |
|
||||||
| ShuffleNetV2 | [shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | |
|
| ShuffleNetV2 | [shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py](https://github.com/open-mmlab/mmclassification/tree/master/configs/shufflenet_v2/shufflenet_v2_1x_b64x16_linearlr_bn_nowd_imagenet.py) | Y | Y | |
|
||||||
|
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
from .test import ONNXRuntimeClassifier
|
from .test import ONNXRuntimeClassifier, TensorRTClassifier
|
||||||
|
|
||||||
__all__ = ['ONNXRuntimeClassifier']
|
__all__ = ['ONNXRuntimeClassifier', 'TensorRTClassifier']
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
import torch
|
||||||
|
|
||||||
from mmcls.models.classifiers import BaseClassifier
|
from mmcls.models.classifiers import BaseClassifier
|
||||||
|
|
||||||
@ -55,3 +58,38 @@ class ONNXRuntimeClassifier(BaseClassifier):
|
|||||||
self.sess.run_with_iobinding(self.io_binding)
|
self.sess.run_with_iobinding(self.io_binding)
|
||||||
results = self.io_binding.copy_outputs_to_cpu()[0]
|
results = self.io_binding.copy_outputs_to_cpu()[0]
|
||||||
return list(results)
|
return list(results)
|
||||||
|
|
||||||
|
|
||||||
|
class TensorRTClassifier(BaseClassifier):
|
||||||
|
|
||||||
|
def __init__(self, trt_file, class_names, device_id):
|
||||||
|
super(TensorRTClassifier, self).__init__()
|
||||||
|
from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
|
||||||
|
try:
|
||||||
|
load_tensorrt_plugin()
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
warnings.warn('If input model has custom op from mmcv, \
|
||||||
|
you may have to build mmcv with TensorRT from source.')
|
||||||
|
model = TRTWraper(
|
||||||
|
trt_file, input_names=['input'], output_names=['probs'])
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.device_id = device_id
|
||||||
|
self.CLASSES = class_names
|
||||||
|
|
||||||
|
def simple_test(self, img, img_metas, **kwargs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def extract_feat(self, imgs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def forward_train(self, imgs, **kwargs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def forward_test(self, imgs, img_metas, **kwargs):
|
||||||
|
input_data = imgs
|
||||||
|
with torch.cuda.device(self.device_id), torch.no_grad():
|
||||||
|
results = self.model({'input': input_data})['probs']
|
||||||
|
results = results.detach().cpu().numpy()
|
||||||
|
|
||||||
|
return list(results)
|
||||||
|
@ -13,6 +13,8 @@ def get_GiB(x: int):
|
|||||||
def onnx2tensorrt(onnx_file,
|
def onnx2tensorrt(onnx_file,
|
||||||
trt_file,
|
trt_file,
|
||||||
input_shape,
|
input_shape,
|
||||||
|
max_batch_size,
|
||||||
|
fp16_mode=False,
|
||||||
verify=False,
|
verify=False,
|
||||||
workspace_size=1):
|
workspace_size=1):
|
||||||
"""Create tensorrt engine from onnx model.
|
"""Create tensorrt engine from onnx model.
|
||||||
@ -22,6 +24,7 @@ def onnx2tensorrt(onnx_file,
|
|||||||
trt_file (str): Filename of the output TensorRT engine file.
|
trt_file (str): Filename of the output TensorRT engine file.
|
||||||
input_shape (list[int]): Input shape of the model.
|
input_shape (list[int]): Input shape of the model.
|
||||||
eg [1, 3, 224, 224].
|
eg [1, 3, 224, 224].
|
||||||
|
max_batch_size (int): Max batch size of the model.
|
||||||
verify (bool, optional): Whether to verify the converted model.
|
verify (bool, optional): Whether to verify the converted model.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
workspace_size (int, optional): Maximium workspace of GPU.
|
workspace_size (int, optional): Maximium workspace of GPU.
|
||||||
@ -32,12 +35,14 @@ def onnx2tensorrt(onnx_file,
|
|||||||
|
|
||||||
onnx_model = onnx.load(onnx_file)
|
onnx_model = onnx.load(onnx_file)
|
||||||
# create trt engine and wraper
|
# create trt engine and wraper
|
||||||
opt_shape_dict = {'input': [input_shape, input_shape, input_shape]}
|
assert max_batch_size >= 1
|
||||||
|
max_shape = [max_batch_size] + list(input_shape[1:])
|
||||||
|
opt_shape_dict = {'input': [input_shape, input_shape, max_shape]}
|
||||||
max_workspace_size = get_GiB(workspace_size)
|
max_workspace_size = get_GiB(workspace_size)
|
||||||
trt_engine = onnx2trt(
|
trt_engine = onnx2trt(
|
||||||
onnx_model,
|
onnx_model,
|
||||||
opt_shape_dict,
|
opt_shape_dict,
|
||||||
fp16_mode=False,
|
fp16_mode=fp16_mode,
|
||||||
max_workspace_size=max_workspace_size)
|
max_workspace_size=max_workspace_size)
|
||||||
save_dir, _ = osp.split(trt_file)
|
save_dir, _ = osp.split(trt_file)
|
||||||
if save_dir:
|
if save_dir:
|
||||||
@ -99,6 +104,12 @@ def parse_args():
|
|||||||
nargs='+',
|
nargs='+',
|
||||||
default=[224, 224],
|
default=[224, 224],
|
||||||
help='Input size of the model')
|
help='Input size of the model')
|
||||||
|
parser.add_argument(
|
||||||
|
'--max-batch-size',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help='Maximum batch size of TensorRT model.')
|
||||||
|
parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--workspace-size',
|
'--workspace-size',
|
||||||
type=int,
|
type=int,
|
||||||
@ -124,5 +135,7 @@ if __name__ == '__main__':
|
|||||||
args.model,
|
args.model,
|
||||||
args.trt_file,
|
args.trt_file,
|
||||||
input_shape,
|
input_shape,
|
||||||
|
args.max_batch_size,
|
||||||
|
fp16_mode=args.fp16,
|
||||||
verify=args.verify,
|
verify=args.verify,
|
||||||
workspace_size=args.workspace_size)
|
workspace_size=args.workspace_size)
|
||||||
|
@ -7,7 +7,7 @@ from mmcv import DictAction
|
|||||||
from mmcv.parallel import MMDataParallel
|
from mmcv.parallel import MMDataParallel
|
||||||
|
|
||||||
from mmcls.apis import single_gpu_test
|
from mmcls.apis import single_gpu_test
|
||||||
from mmcls.core.export import ONNXRuntimeClassifier
|
from mmcls.core.export import ONNXRuntimeClassifier, TensorRTClassifier
|
||||||
from mmcls.datasets import build_dataloader, build_dataset
|
from mmcls.datasets import build_dataloader, build_dataset
|
||||||
|
|
||||||
|
|
||||||
@ -16,6 +16,10 @@ def parse_args():
|
|||||||
description='Test (and eval) an ONNX model using ONNXRuntime.')
|
description='Test (and eval) an ONNX model using ONNXRuntime.')
|
||||||
parser.add_argument('config', help='model config file')
|
parser.add_argument('config', help='model config file')
|
||||||
parser.add_argument('model', help='filename of the input ONNX model')
|
parser.add_argument('model', help='filename of the input ONNX model')
|
||||||
|
parser.add_argument(
|
||||||
|
'--backend',
|
||||||
|
help='Backend of the model.',
|
||||||
|
choices=['onnxruntime', 'tensorrt'])
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--out', type=str, help='output result file in pickle format')
|
'--out', type=str, help='output result file in pickle format')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -67,10 +71,18 @@ def main():
|
|||||||
round_up=False)
|
round_up=False)
|
||||||
|
|
||||||
# build onnxruntime model and run inference.
|
# build onnxruntime model and run inference.
|
||||||
model = ONNXRuntimeClassifier(
|
if args.backend == 'onnxruntime':
|
||||||
args.model, class_names=dataset.CLASSES, device_id=0)
|
model = ONNXRuntimeClassifier(
|
||||||
|
args.model, class_names=dataset.CLASSES, device_id=0)
|
||||||
|
elif args.backend == 'tensorrt':
|
||||||
|
model = TensorRTClassifier(
|
||||||
|
args.model, class_names=dataset.CLASSES, device_id=0)
|
||||||
|
else:
|
||||||
|
print('Unknown backend: {}.'.format(args.model))
|
||||||
|
exit()
|
||||||
|
|
||||||
model = MMDataParallel(model, device_ids=[0])
|
model = MMDataParallel(model, device_ids=[0])
|
||||||
|
model.CLASSES = dataset.CLASSES
|
||||||
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
|
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
|
||||||
|
|
||||||
if args.metrics:
|
if args.metrics:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user