[Enhancement]: Support opset_version 13 (#2071)

* upgrade to opset 13

* fix unsqueeze

* fix mmseg yml

* fix mmseg reg test

* forcely change opset13

* fix mmdet3d

* optimize squeeze

* update base dockerfile

* support squeeze/unsqueeze with axes as input in onnx2ncnn

* update optimizer for squeeze/unsqueeze

* revert

* Revert "support squeeze/unsqueeze with axes as input in onnx2ncnn"

This reverts commit 5ca9f1ae172cb4e1625f150ccb049138b5f37aa3.

* fix docs

* fix opset
This commit is contained in:
RunningLeon 2023-05-17 11:02:30 +08:00 committed by GitHub
parent 389a146212
commit 1c7749d17c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 69 additions and 31 deletions

View File

@ -108,3 +108,13 @@ RUN wget -c $TENSORRT_URL && \
ENV TENSORRT_DIR=/root/workspace/TensorRT ENV TENSORRT_DIR=/root/workspace/TensorRT
ENV LD_LIBRARY_PATH=$TENSORRT_DIR/lib:$LD_LIBRARY_PATH ENV LD_LIBRARY_PATH=$TENSORRT_DIR/lib:$LD_LIBRARY_PATH
ENV PATH=$TENSORRT_DIR/bin:$PATH ENV PATH=$TENSORRT_DIR/bin:$PATH
# openvino
RUN wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2022.3/linux/l_openvino_toolkit_ubuntu20_2022.3.0.9052.9752fafe8eb_x86_64.tgz &&\
tar -zxvf ./l_openvino_toolkit*.tgz &&\
rm ./l_openvino_toolkit*.tgz &&\
mv ./l_openvino* ./openvino_toolkit &&\
bash ./openvino_toolkit/install_dependencies/install_openvino_dependencies.sh
ENV OPENVINO_DIR=/root/workspace/openvino_toolkit
ENV InferenceEngine_DIR=$OPENVINO_DIR/runtime/cmake

View File

@ -1,15 +1,38 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable from typing import Callable
import torch
from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.core import FUNCTION_REWRITER
def update_squeeze_unsqueeze_opset13_pass(graph, params_dict, torch_out):
"""Update Squeeze/Unsqueeze axes for opset13."""
for node in graph.nodes():
if node.kind() in ['onnx::Squeeze', 'onnx::Unsqueeze'] and \
node.hasAttribute('axes'):
axes = node['axes']
axes_node = graph.create('onnx::Constant')
axes_node.t_('value', torch.LongTensor(axes))
node.removeAttribute('axes')
node.addInput(axes_node.output())
axes_node.insertBefore(node)
return graph, params_dict, torch_out
@FUNCTION_REWRITER.register_rewriter('torch.onnx.utils._model_to_graph') @FUNCTION_REWRITER.register_rewriter('torch.onnx.utils._model_to_graph')
def model_to_graph__custom_optimizer(*args, **kwargs): def model_to_graph__custom_optimizer(*args, **kwargs):
"""Rewriter of _model_to_graph, add custom passes.""" """Rewriter of _model_to_graph, add custom passes."""
ctx = FUNCTION_REWRITER.get_context() ctx = FUNCTION_REWRITER.get_context()
graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs) graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs)
if hasattr(ctx, 'opset'):
opset_version = ctx.opset
else:
from mmdeploy.utils import get_ir_config
opset_version = get_ir_config(ctx.cfg).get('opset_version', 11)
if opset_version >= 13:
graph, params_dict, torch_out = update_squeeze_unsqueeze_opset13_pass(
graph, params_dict, torch_out)
custom_passes = getattr(ctx, 'onnx_custom_passes', None) custom_passes = getattr(ctx, 'onnx_custom_passes', None)
if custom_passes is not None: if custom_passes is not None:

View File

@ -18,6 +18,6 @@ def optimize_onnx(ctx, graph, params_dict, torch_out):
logger.warning( logger.warning(
'Can not optimize model, please build torchscipt extension.\n' 'Can not optimize model, please build torchscipt extension.\n'
'More details: ' 'More details: '
'https://github.com/open-mmlab/mmdeploy/tree/1.x/docs/en/experimental/onnx_optimizer.md' # noqa 'https://github.com/open-mmlab/mmdeploy/tree/main/docs/en/experimental/onnx_optimizer.md' # noqa
) )
return graph, params_dict, torch_out return graph, params_dict, torch_out

View File

@ -48,8 +48,8 @@ class GridPriorsTRTOp(torch.autograd.Function):
stride_w: int): stride_w: int):
"""Map ops to onnx symbolics.""" """Map ops to onnx symbolics."""
# zero_h and zero_w is used to provide shape to GridPriorsTRT # zero_h and zero_w is used to provide shape to GridPriorsTRT
feat_h = g.op('Unsqueeze', feat_h, axes_i=[0]) feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0])
feat_w = g.op('Unsqueeze', feat_w, axes_i=[0]) feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0])
zero_h = g.op( zero_h = g.op(
'ConstantOfShape', 'ConstantOfShape',
feat_h, feat_h,

View File

@ -92,14 +92,14 @@ class VoxelDetection(BaseTask):
def create_input( def create_input(
self, self,
pcd: str, pcd: Union[str, Sequence[str]],
input_shape: Sequence[int] = None, input_shape: Sequence[int] = None,
data_preprocessor: Optional[BaseDataPreprocessor] = None data_preprocessor: Optional[BaseDataPreprocessor] = None
) -> Tuple[Dict, torch.Tensor]: ) -> Tuple[Dict, torch.Tensor]:
"""Create input for detector. """Create input for detector.
Args: Args:
pcd (str): Input pcd file path. pcd (str, Sequence[str]): Input pcd file path.
input_shape (Sequence[int], optional): model input shape. input_shape (Sequence[int], optional): model input shape.
Defaults to None. Defaults to None.
data_preprocessor (Optional[BaseDataPreprocessor], optional): data_preprocessor (Optional[BaseDataPreprocessor], optional):
@ -115,7 +115,9 @@ class VoxelDetection(BaseTask):
test_pipeline = Compose(test_pipeline) test_pipeline = Compose(test_pipeline)
box_type_3d, box_mode_3d = \ box_type_3d, box_mode_3d = \
get_box_type(cfg.test_dataloader.dataset.box_type_3d) get_box_type(cfg.test_dataloader.dataset.box_type_3d)
# do not support batch inference
if isinstance(pcd, (list, tuple)):
pcd = pcd[0]
data = [] data = []
data_ = dict( data_ = dict(
lidar_points=dict(lidar_path=pcd), lidar_points=dict(lidar_path=pcd),

View File

@ -122,7 +122,7 @@ torchscript:
models: models:
- name: FCN - name: FCN
metafile: configs/fcn/fcn.yml metafile: configs/fcn/metafile.yaml
model_configs: model_configs:
- configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py - configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
@ -134,7 +134,7 @@ models:
- *pipeline_openvino_dynamic_fp32 - *pipeline_openvino_dynamic_fp32
- name: PSPNet - name: PSPNet
metafile: configs/pspnet/pspnet.yml metafile: configs/pspnet/metafile.yaml
model_configs: model_configs:
- configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py - configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
@ -146,7 +146,7 @@ models:
- *pipeline_openvino_static_fp32 - *pipeline_openvino_static_fp32
- name: deeplabv3 - name: deeplabv3
metafile: configs/deeplabv3/deeplabv3.yml metafile: configs/deeplabv3/metafile.yaml
model_configs: model_configs:
- configs/deeplabv3/deeplabv3_r50-d8_4xb2-40k_cityscapes-512x1024.py - configs/deeplabv3/deeplabv3_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
@ -158,7 +158,7 @@ models:
- *pipeline_openvino_dynamic_fp32 - *pipeline_openvino_dynamic_fp32
- name: deeplabv3+ - name: deeplabv3+
metafile: configs/deeplabv3plus/deeplabv3plus.yml metafile: configs/deeplabv3plus/metafile.yaml
model_configs: model_configs:
- configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-40k_cityscapes-512x1024.py - configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
@ -170,7 +170,7 @@ models:
- *pipeline_openvino_dynamic_fp32 - *pipeline_openvino_dynamic_fp32
- name: Fast-SCNN - name: Fast-SCNN
metafile: configs/fastscnn/fastscnn.yml metafile: configs/fastscnn/metafile.yaml
model_configs: model_configs:
- configs/fastscnn/fast_scnn_8xb4-160k_cityscapes-512x1024.py - configs/fastscnn/fast_scnn_8xb4-160k_cityscapes-512x1024.py
pipelines: pipelines:
@ -181,7 +181,7 @@ models:
- *pipeline_openvino_static_fp32 - *pipeline_openvino_static_fp32
- name: UNet - name: UNet
metafile: configs/unet/unet.yml metafile: configs/unet/metafile.yaml
model_configs: model_configs:
- configs/unet/unet-s5-d16_fcn_4xb4-160k_cityscapes-512x1024.py - configs/unet/unet-s5-d16_fcn_4xb4-160k_cityscapes-512x1024.py
pipelines: pipelines:
@ -192,7 +192,7 @@ models:
- *pipeline_pplnn_dynamic_fp32 - *pipeline_pplnn_dynamic_fp32
- name: ANN - name: ANN
metafile: configs/ann/ann.yml metafile: configs/ann/metafile.yaml
model_configs: model_configs:
- configs/ann/ann_r50-d8_4xb2-40k_cityscapes-512x1024.py - configs/ann/ann_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
@ -201,7 +201,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: APCNet - name: APCNet
metafile: configs/apcnet/apcnet.yml metafile: configs/apcnet/metafile.yaml
model_configs: model_configs:
- configs/apcnet/apcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py - configs/apcnet/apcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
@ -211,7 +211,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: BiSeNetV1 - name: BiSeNetV1
metafile: configs/bisenetv1/bisenetv1.yml metafile: configs/bisenetv1/metafile.yaml
model_configs: model_configs:
- configs/bisenetv1/bisenetv1_r18-d32_4xb4-160k_cityscapes-1024x1024.py - configs/bisenetv1/bisenetv1_r18-d32_4xb4-160k_cityscapes-1024x1024.py
pipelines: pipelines:
@ -222,7 +222,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: BiSeNetV2 - name: BiSeNetV2
metafile: configs/bisenetv2/bisenetv2.yml metafile: configs/bisenetv2/metafile.yaml
model_configs: model_configs:
- configs/bisenetv2/bisenetv2_fcn_4xb4-160k_cityscapes-1024x1024.py - configs/bisenetv2/bisenetv2_fcn_4xb4-160k_cityscapes-1024x1024.py
pipelines: pipelines:
@ -233,7 +233,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: CGNet - name: CGNet
metafile: configs/cgnet/cgnet.yml metafile: configs/cgnet/metafile.yaml
model_configs: model_configs:
- configs/cgnet/cgnet_fcn_4xb8-60k_cityscapes-512x1024.py - configs/cgnet/cgnet_fcn_4xb8-60k_cityscapes-512x1024.py
pipelines: pipelines:
@ -244,7 +244,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: EMANet - name: EMANet
metafile: configs/emanet/emanet.yml metafile: configs/emanet/metafile.yaml
model_configs: model_configs:
- configs/emanet/emanet_r50-d8_4xb2-80k_cityscapes-512x1024.py - configs/emanet/emanet_r50-d8_4xb2-80k_cityscapes-512x1024.py
pipelines: pipelines:
@ -254,7 +254,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: EncNet - name: EncNet
metafile: configs/encnet/encnet.yml metafile: configs/encnet/metafile.yaml
model_configs: model_configs:
- configs/encnet/encnet_r50-d8_4xb2-40k_cityscapes-512x1024.py - configs/encnet/encnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
@ -264,7 +264,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: ERFNet - name: ERFNet
metafile: configs/erfnet/erfnet.yml metafile: configs/erfnet/metafile.yaml
model_configs: model_configs:
- configs/erfnet/erfnet_fcn_4xb4-160k_cityscapes-512x1024.py - configs/erfnet/erfnet_fcn_4xb4-160k_cityscapes-512x1024.py
pipelines: pipelines:
@ -275,7 +275,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: FastFCN - name: FastFCN
metafile: configs/fastfcn/fastfcn.yml metafile: configs/fastfcn/metafile.yaml
model_configs: model_configs:
- configs/fastfcn/fastfcn_r50-d32_jpu_aspp_4xb2-80k_cityscapes-512x1024.py - configs/fastfcn/fastfcn_r50-d32_jpu_aspp_4xb2-80k_cityscapes-512x1024.py
pipelines: pipelines:
@ -286,7 +286,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: GCNet - name: GCNet
metafile: configs/gcnet/gcnet.yml metafile: configs/gcnet/metafile.yaml
model_configs: model_configs:
- configs/gcnet/gcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py - configs/gcnet/gcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
@ -295,7 +295,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: ICNet - name: ICNet
metafile: configs/icnet/icnet.yml metafile: configs/icnet/metafile.yaml
model_configs: model_configs:
- configs/icnet/icnet_r18-d8_4xb2-80k_cityscapes-832x832.py - configs/icnet/icnet_r18-d8_4xb2-80k_cityscapes-832x832.py
pipelines: pipelines:
@ -305,7 +305,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: ISANet - name: ISANet
metafile: configs/isanet/isanet.yml metafile: configs/isanet/metafile.yaml
model_configs: model_configs:
- configs/isanet/isanet_r50-d8_4xb2-40k_cityscapes-512x1024.py - configs/isanet/isanet_r50-d8_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
@ -314,7 +314,7 @@ models:
- *pipeline_openvino_static_fp32_512x512 - *pipeline_openvino_static_fp32_512x512
- name: OCRNet - name: OCRNet
metafile: configs/ocrnet/ocrnet.yml metafile: configs/ocrnet/metafile.yaml
model_configs: model_configs:
- configs/ocrnet/ocrnet_hr18s_4xb2-40k_cityscapes-512x1024.py - configs/ocrnet/ocrnet_hr18s_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
@ -325,7 +325,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: PointRend - name: PointRend
metafile: configs/point_rend/point_rend.yml metafile: configs/point_rend/metafile.yaml
model_configs: model_configs:
- configs/point_rend/pointrend_r50_4xb2-80k_cityscapes-512x1024.py - configs/point_rend/pointrend_r50_4xb2-80k_cityscapes-512x1024.py
pipelines: pipelines:
@ -334,7 +334,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: Semantic FPN - name: Semantic FPN
metafile: configs/sem_fpn/sem_fpn.yml metafile: configs/sem_fpn/metafile.yaml
model_configs: model_configs:
- configs/sem_fpn/fpn_r50_4xb2-80k_cityscapes-512x1024.py - configs/sem_fpn/fpn_r50_4xb2-80k_cityscapes-512x1024.py
pipelines: pipelines:
@ -345,7 +345,7 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: STDC - name: STDC
metafile: configs/stdc/stdc.yml metafile: configs/stdc/metafile.yaml
model_configs: model_configs:
- configs/stdc/stdc1_in1k-pre_4xb12-80k_cityscapes-512x1024.py - configs/stdc/stdc1_in1k-pre_4xb12-80k_cityscapes-512x1024.py
- configs/stdc/stdc2_in1k-pre_4xb12-80k_cityscapes-512x1024.py - configs/stdc/stdc2_in1k-pre_4xb12-80k_cityscapes-512x1024.py
@ -357,14 +357,14 @@ models:
- *pipeline_ts_fp32 - *pipeline_ts_fp32
- name: UPerNet - name: UPerNet
metafile: configs/upernet/upernet.yml metafile: configs/upernet/metafile.yaml
model_configs: model_configs:
- configs/upernet/upernet_r50_4xb2-40k_cityscapes-512x1024.py - configs/upernet/upernet_r50_4xb2-40k_cityscapes-512x1024.py
pipelines: pipelines:
- *pipeline_ort_static_fp32 - *pipeline_ort_static_fp32
- *pipeline_trt_static_fp16 - *pipeline_trt_static_fp16
- name: Segmenter - name: Segmenter
metafile: configs/segmenter/segmenter.yml metafile: configs/segmenter/metafile.yaml
model_configs: model_configs:
- configs/segmenter/segmenter_vit-s_fcn_8xb1-160k_ade20k-512x512.py - configs/segmenter/segmenter_vit-s_fcn_8xb1-160k_ade20k-512x512.py
pipelines: pipelines:

View File

@ -302,6 +302,9 @@ def get_pytorch_result(model_name: str, meta_info: dict, checkpoint_path: Path,
# get metric # get metric
model_info = meta_info[model_config_name] model_info = meta_info[model_config_name]
metafile_metric_info = model_info['Results'] metafile_metric_info = model_info['Results']
# deal with mmseg case
if not isinstance(metafile_metric_info, (list, tuple)):
metafile_metric_info = [metafile_metric_info]
pytorch_metric = dict() pytorch_metric = dict()
using_dataset = set() using_dataset = set()
using_task = set() using_task = set()