mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Update deploy test tools (#553)
* add trt test tool * create deploy_test, update document * fix with isort * move import inside __init__ * remove comment, fix doc * update document
This commit is contained in:
parent
1f2d354b70
commit
597736288c
@ -76,9 +76,9 @@ Description of arguments:
|
|||||||
|
|
||||||
**Note**: This tool is still experimental. Some customized operators are not supported for now.
|
**Note**: This tool is still experimental. Some customized operators are not supported for now.
|
||||||
|
|
||||||
### Evaluate ONNX model with ONNXRuntime
|
### Evaluate ONNX model
|
||||||
|
|
||||||
We provide `tools/ort_test.py` to evaluate ONNX model with ONNXRuntime backend.
|
We provide `tools/deploy_test.py` to evaluate ONNX model with different backend.
|
||||||
|
|
||||||
#### Prerequisite
|
#### Prerequisite
|
||||||
|
|
||||||
@ -88,12 +88,15 @@ We provide `tools/ort_test.py` to evaluate ONNX model with ONNXRuntime backend.
|
|||||||
pip install onnx onnxruntime-gpu
|
pip install onnx onnxruntime-gpu
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- Install TensorRT following [how-to-build-tensorrt-plugins-in-mmcv](https://mmcv.readthedocs.io/en/latest/tensorrt_plugin.html#how-to-build-tensorrt-plugins-in-mmcv)(optional)
|
||||||
|
|
||||||
#### Usage
|
#### Usage
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python tools/ort_test.py \
|
python tools/deploy_test.py \
|
||||||
${CONFIG_FILE} \
|
${CONFIG_FILE} \
|
||||||
${ONNX_FILE} \
|
${MODEL_FILE} \
|
||||||
|
${BACKEND} \
|
||||||
--out ${OUTPUT_FILE} \
|
--out ${OUTPUT_FILE} \
|
||||||
--eval ${EVALUATION_METRICS} \
|
--eval ${EVALUATION_METRICS} \
|
||||||
--show \
|
--show \
|
||||||
@ -106,7 +109,8 @@ python tools/ort_test.py \
|
|||||||
Description of all arguments
|
Description of all arguments
|
||||||
|
|
||||||
- `config`: The path of a model config file.
|
- `config`: The path of a model config file.
|
||||||
- `model`: The path of a ONNX model file.
|
- `model`: The path of a converted model file.
|
||||||
|
- `backend`: Backend of the inference, options: `onnxruntime`, `tensorrt`.
|
||||||
- `--out`: The path of output result file in pickle format.
|
- `--out`: The path of output result file in pickle format.
|
||||||
- `--format-only` : Format the output results without perform evaluation. It is useful when you want to format the result to a specific format and submit it to the test server. If not specified, it will be set to `False`. Note that this argument is **mutually exclusive** with `--eval`.
|
- `--format-only` : Format the output results without perform evaluation. It is useful when you want to format the result to a specific format and submit it to the test server. If not specified, it will be set to `False`. Note that this argument is **mutually exclusive** with `--eval`.
|
||||||
- `--eval`: Evaluation metrics, which depends on the dataset, e.g., "mIoU" for generic datasets, and "cityscapes" for Cityscapes. Note that this argument is **mutually exclusive** with `--format-only`.
|
- `--eval`: Evaluation metrics, which depends on the dataset, e.g., "mIoU" for generic datasets, and "cityscapes" for Cityscapes. Note that this argument is **mutually exclusive** with `--format-only`.
|
||||||
@ -118,12 +122,17 @@ Description of all arguments
|
|||||||
|
|
||||||
#### Results and Models
|
#### Results and Models
|
||||||
|
|
||||||
| Model | Config | Dataset | Metric | PyTorch | ONNXRuntime |
|
| Model | Config | Dataset | Metric | PyTorch | ONNXRuntime | TensorRT-fp32 | TensorRT-fp16 |
|
||||||
| :--------: | :--------------------------------------------: | :--------: | :----: | :-----: | :---------: |
|
| :--------: | :---------------------------------------------: | :--------: | :----: | :-----: | :---------: | :-----------: | :-----------: |
|
||||||
| FCN | fcn_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 72.2 | 72.2 |
|
| FCN | fcn_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 72.2 | 72.2 | 72.2 | 72.2 |
|
||||||
| PSPNet | pspnet_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.2 | 78.1 |
|
| PSPNet | pspnet_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 77.8 | 77.8 | 77.8 | 77.8 |
|
||||||
| deeplabv3 | deeplabv3_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.5 | 78.3 |
|
| deeplabv3 | deeplabv3_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 79.0 | 79.0 | 79.0 | 79.0 |
|
||||||
| deeplabv3+ | deeplabv3plus_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.9 | 78.7 |
|
| deeplabv3+ | deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 79.6 | 79.5 | 79.5 | 79.5 |
|
||||||
|
| PSPNet | pspnet_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.2 | 78.1 | | |
|
||||||
|
| deeplabv3 | deeplabv3_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.5 | 78.3 | | |
|
||||||
|
| deeplabv3+ | deeplabv3plus_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.9 | 78.7 | | |
|
||||||
|
|
||||||
|
**Note**: TensorRT is only available on configs with `whole mode`.
|
||||||
|
|
||||||
### Convert to TorchScript (experimental)
|
### Convert to TorchScript (experimental)
|
||||||
|
|
||||||
|
@ -2,10 +2,10 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import warnings
|
import warnings
|
||||||
|
from typing import Any, Iterable
|
||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
|
||||||
import torch
|
import torch
|
||||||
from mmcv.parallel import MMDataParallel
|
from mmcv.parallel import MMDataParallel
|
||||||
from mmcv.runner import get_dist_info
|
from mmcv.runner import get_dist_info
|
||||||
@ -18,8 +18,10 @@ from mmseg.models.segmentors.base import BaseSegmentor
|
|||||||
|
|
||||||
class ONNXRuntimeSegmentor(BaseSegmentor):
|
class ONNXRuntimeSegmentor(BaseSegmentor):
|
||||||
|
|
||||||
def __init__(self, onnx_file, cfg, device_id):
|
def __init__(self, onnx_file: str, cfg: Any, device_id: int):
|
||||||
super(ONNXRuntimeSegmentor, self).__init__()
|
super(ONNXRuntimeSegmentor, self).__init__()
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
# get the custom op path
|
# get the custom op path
|
||||||
ort_custom_op_path = ''
|
ort_custom_op_path = ''
|
||||||
try:
|
try:
|
||||||
@ -60,7 +62,8 @@ class ONNXRuntimeSegmentor(BaseSegmentor):
|
|||||||
def forward_train(self, imgs, img_metas, **kwargs):
|
def forward_train(self, imgs, img_metas, **kwargs):
|
||||||
raise NotImplementedError('This method is not implemented.')
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
def simple_test(self, img, img_meta, **kwargs):
|
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
|
||||||
|
**kwargs) -> list:
|
||||||
device_type = img.device.type
|
device_type = img.device.type
|
||||||
self.io_binding.bind_input(
|
self.io_binding.bind_input(
|
||||||
name='input',
|
name='input',
|
||||||
@ -87,11 +90,63 @@ class ONNXRuntimeSegmentor(BaseSegmentor):
|
|||||||
raise NotImplementedError('This method is not implemented.')
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
class TensorRTSegmentor(BaseSegmentor):
|
||||||
|
|
||||||
|
def __init__(self, trt_file: str, cfg: Any, device_id: int):
|
||||||
|
super(TensorRTSegmentor, self).__init__()
|
||||||
|
from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
|
||||||
|
try:
|
||||||
|
load_tensorrt_plugin()
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
warnings.warn('If input model has custom op from mmcv, \
|
||||||
|
you may have to build mmcv with TensorRT from source.')
|
||||||
|
model = TRTWraper(
|
||||||
|
trt_file, input_names=['input'], output_names=['output'])
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.device_id = device_id
|
||||||
|
self.cfg = cfg
|
||||||
|
self.test_mode = cfg.model.test_cfg.mode
|
||||||
|
|
||||||
|
def extract_feat(self, imgs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def encode_decode(self, img, img_metas):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def forward_train(self, imgs, img_metas, **kwargs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
|
||||||
|
**kwargs) -> list:
|
||||||
|
with torch.cuda.device(self.device_id), torch.no_grad():
|
||||||
|
seg_pred = self.model({'input': img})['output']
|
||||||
|
seg_pred = seg_pred.detach().cpu().numpy()
|
||||||
|
# whole might support dynamic reshape
|
||||||
|
ori_shape = img_meta[0]['ori_shape']
|
||||||
|
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||||
|
and ori_shape[1] == seg_pred.shape[-1]):
|
||||||
|
seg_pred = torch.from_numpy(seg_pred).float()
|
||||||
|
seg_pred = torch.nn.functional.interpolate(
|
||||||
|
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
||||||
|
seg_pred = seg_pred.long().detach().cpu().numpy()
|
||||||
|
seg_pred = seg_pred[0]
|
||||||
|
seg_pred = list(seg_pred)
|
||||||
|
return seg_pred
|
||||||
|
|
||||||
|
def aug_test(self, imgs, img_metas, **kwargs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='mmseg onnxruntime backend test (and eval) a model')
|
description='mmseg backend test (and eval)')
|
||||||
parser.add_argument('config', help='test config file path')
|
parser.add_argument('config', help='test config file path')
|
||||||
parser.add_argument('model', help='Input model file')
|
parser.add_argument('model', help='Input model file')
|
||||||
|
parser.add_argument(
|
||||||
|
'--backend',
|
||||||
|
help='Backend of the model.',
|
||||||
|
choices=['onnxruntime', 'tensorrt'])
|
||||||
parser.add_argument('--out', help='output result file in pickle format')
|
parser.add_argument('--out', help='output result file in pickle format')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--format-only',
|
'--format-only',
|
||||||
@ -163,7 +218,12 @@ def main():
|
|||||||
|
|
||||||
# load onnx config and meta
|
# load onnx config and meta
|
||||||
cfg.model.train_cfg = None
|
cfg.model.train_cfg = None
|
||||||
|
|
||||||
|
if args.backend == 'onnxruntime':
|
||||||
model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
|
model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
|
||||||
|
elif args.backend == 'tensorrt':
|
||||||
|
model = TensorRTSegmentor(args.model, cfg=cfg, device_id=0)
|
||||||
|
|
||||||
model.CLASSES = dataset.CLASSES
|
model.CLASSES = dataset.CLASSES
|
||||||
model.PALETTE = dataset.PALETTE
|
model.PALETTE = dataset.PALETTE
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user