From dc5d53b7f3b8fa754cfb66789e8835d15f7b0867 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 25 May 2021 11:37:46 +0800 Subject: [PATCH] [Feature] Update deploy test tools (#553) * add trt test tool * create deploy_test, update document * fix with isort * move import inside __init__ * remove comment, fix doc * update document --- docs/useful_tools.md | 31 ++++++++---- tools/{ort_test.py => deploy_test.py} | 72 ++++++++++++++++++++++++--- 2 files changed, 86 insertions(+), 17 deletions(-) rename tools/{ort_test.py => deploy_test.py} (71%) diff --git a/docs/useful_tools.md b/docs/useful_tools.md index 8ae19f5be..de5e127b1 100644 --- a/docs/useful_tools.md +++ b/docs/useful_tools.md @@ -76,9 +76,9 @@ Description of arguments: **Note**: This tool is still experimental. Some customized operators are not supported for now. -### Evaluate ONNX model with ONNXRuntime +### Evaluate ONNX model -We provide `tools/ort_test.py` to evaluate ONNX model with ONNXRuntime backend. +We provide `tools/deploy_test.py` to evaluate ONNX model with different backends. #### Prerequisite @@ -88,12 +88,15 @@ We provide `tools/ort_test.py` to evaluate ONNX model with ONNXRuntime backend. pip install onnx onnxruntime-gpu ``` +- Install TensorRT following [how-to-build-tensorrt-plugins-in-mmcv](https://mmcv.readthedocs.io/en/latest/tensorrt_plugin.html#how-to-build-tensorrt-plugins-in-mmcv) (optional) + #### Usage ```bash -python tools/ort_test.py \ +python tools/deploy_test.py \ ${CONFIG_FILE} \ - ${ONNX_FILE} \ + ${MODEL_FILE} \ + --backend ${BACKEND} \ --out ${OUTPUT_FILE} \ --eval ${EVALUATION_METRICS} \ --show \ @@ -106,7 +109,8 @@ python tools/ort_test.py \ Description of all arguments - `config`: The path of a model config file. -- `model`: The path of a ONNX model file. +- `model`: The path of a converted model file. +- `--backend`: Backend of the inference, options: `onnxruntime`, `tensorrt`. - `--out`: The path of output result file in pickle format. 
- `--format-only` : Format the output results without perform evaluation. It is useful when you want to format the result to a specific format and submit it to the test server. If not specified, it will be set to `False`. Note that this argument is **mutually exclusive** with `--eval`. - `--eval`: Evaluation metrics, which depends on the dataset, e.g., "mIoU" for generic datasets, and "cityscapes" for Cityscapes. Note that this argument is **mutually exclusive** with `--format-only`. @@ -118,12 +122,17 @@ Description of all arguments #### Results and Models -| Model | Config | Dataset | Metric | PyTorch | ONNXRuntime | -| :--------: | :--------------------------------------------: | :--------: | :----: | :-----: | :---------: | -| FCN | fcn_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 72.2 | 72.2 | -| PSPNet | pspnet_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.2 | 78.1 | -| deeplabv3 | deeplabv3_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.5 | 78.3 | -| deeplabv3+ | deeplabv3plus_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.9 | 78.7 | +| Model | Config | Dataset | Metric | PyTorch | ONNXRuntime | TensorRT-fp32 | TensorRT-fp16 | +| :--------: | :---------------------------------------------: | :--------: | :----: | :-----: | :---------: | :-----------: | :-----------: | +| FCN | fcn_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 72.2 | 72.2 | 72.2 | 72.2 | +| PSPNet | pspnet_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 77.8 | 77.8 | 77.8 | 77.8 | +| deeplabv3 | deeplabv3_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 79.0 | 79.0 | 79.0 | 79.0 | +| deeplabv3+ | deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 79.6 | 79.5 | 79.5 | 79.5 | +| PSPNet | pspnet_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.2 | 78.1 | | | +| deeplabv3 | deeplabv3_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.5 | 78.3 | | | +| deeplabv3+ | 
deeplabv3plus_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.9 | 78.7 | | | + +**Note**: TensorRT is only available on configs with `whole mode`. ### Convert to TorchScript (experimental) diff --git a/tools/ort_test.py b/tools/deploy_test.py similarity index 71% rename from tools/ort_test.py rename to tools/deploy_test.py index 807b21272..bef3512d7 100644 --- a/tools/ort_test.py +++ b/tools/deploy_test.py @@ -2,10 +2,10 @@ import argparse import os import os.path as osp import warnings +from typing import Any, Iterable import mmcv import numpy as np -import onnxruntime as ort import torch from mmcv.parallel import MMDataParallel from mmcv.runner import get_dist_info @@ -18,8 +18,10 @@ from mmseg.models.segmentors.base import BaseSegmentor class ONNXRuntimeSegmentor(BaseSegmentor): - def __init__(self, onnx_file, cfg, device_id): + def __init__(self, onnx_file: str, cfg: Any, device_id: int): super(ONNXRuntimeSegmentor, self).__init__() + import onnxruntime as ort + # get the custom op path ort_custom_op_path = '' try: @@ -60,7 +62,8 @@ class ONNXRuntimeSegmentor(BaseSegmentor): def forward_train(self, imgs, img_metas, **kwargs): raise NotImplementedError('This method is not implemented.') - def simple_test(self, img, img_meta, **kwargs): + def simple_test(self, img: torch.Tensor, img_meta: Iterable, + **kwargs) -> list: device_type = img.device.type self.io_binding.bind_input( name='input', @@ -87,11 +90,63 @@ class ONNXRuntimeSegmentor(BaseSegmentor): raise NotImplementedError('This method is not implemented.') -def parse_args(): +class TensorRTSegmentor(BaseSegmentor): + + def __init__(self, trt_file: str, cfg: Any, device_id: int): + super(TensorRTSegmentor, self).__init__() + from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin + try: + load_tensorrt_plugin() + except (ImportError, ModuleNotFoundError): + warnings.warn('If input model has custom op from mmcv, \ + you may have to build mmcv with TensorRT from source.') + model = TRTWraper( + 
trt_file, input_names=['input'], output_names=['output']) + + self.model = model + self.device_id = device_id + self.cfg = cfg + self.test_mode = cfg.model.test_cfg.mode + + def extract_feat(self, imgs): + raise NotImplementedError('This method is not implemented.') + + def encode_decode(self, img, img_metas): + raise NotImplementedError('This method is not implemented.') + + def forward_train(self, imgs, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def simple_test(self, img: torch.Tensor, img_meta: Iterable, + **kwargs) -> list: + with torch.cuda.device(self.device_id), torch.no_grad(): + seg_pred = self.model({'input': img})['output'] + seg_pred = seg_pred.detach().cpu().numpy() + # whole might support dynamic reshape + ori_shape = img_meta[0]['ori_shape'] + if not (ori_shape[0] == seg_pred.shape[-2] + and ori_shape[1] == seg_pred.shape[-1]): + seg_pred = torch.from_numpy(seg_pred).float() + seg_pred = torch.nn.functional.interpolate( + seg_pred, size=tuple(ori_shape[:2]), mode='nearest') + seg_pred = seg_pred.long().detach().cpu().numpy() + seg_pred = seg_pred[0] + seg_pred = list(seg_pred) + return seg_pred + + def aug_test(self, imgs, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') + + +def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description='mmseg onnxruntime backend test (and eval) a model') + description='mmseg backend test (and eval)') parser.add_argument('config', help='test config file path') parser.add_argument('model', help='Input model file') + parser.add_argument( + '--backend', + help='Backend of the model.', + choices=['onnxruntime', 'tensorrt']) parser.add_argument('--out', help='output result file in pickle format') parser.add_argument( '--format-only', @@ -163,7 +218,12 @@ def main(): # load onnx config and meta cfg.model.train_cfg = None - model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0) + + if args.backend == 
'onnxruntime': + model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0) + elif args.backend == 'tensorrt': + model = TensorRTSegmentor(args.model, cfg=cfg, device_id=0) + model.CLASSES = dataset.CLASSES model.PALETTE = dataset.PALETTE