[Enhancement]: Support opset_version 13 (#2071)
* upgrade to opset 13 * fix unsqueeze * fix mmseg yml * fix mmseg reg test * forcely change opset13 * fix mmdet3d * optimize squeeze * update base dockerfile * support squeeze/unsqueeze with axes as input in onnx2ncnn * update optimizer for squeeze/unsqueeze * revert * Revert "support squeeze/unsqueeze with axes as input in onnx2ncnn" This reverts commit 5ca9f1ae172cb4e1625f150ccb049138b5f37aa3. * fix docs * fix opsetpull/2095/head
parent
389a146212
commit
1c7749d17c
|
@ -108,3 +108,13 @@ RUN wget -c $TENSORRT_URL && \
|
|||
ENV TENSORRT_DIR=/root/workspace/TensorRT
|
||||
ENV LD_LIBRARY_PATH=$TENSORRT_DIR/lib:$LD_LIBRARY_PATH
|
||||
ENV PATH=$TENSORRT_DIR/bin:$PATH
|
||||
|
||||
# openvino
|
||||
RUN wget https://storage.openvinotoolkit.org/repositories/openvino/packages/2022.3/linux/l_openvino_toolkit_ubuntu20_2022.3.0.9052.9752fafe8eb_x86_64.tgz &&\
|
||||
tar -zxvf ./l_openvino_toolkit*.tgz &&\
|
||||
rm ./l_openvino_toolkit*.tgz &&\
|
||||
mv ./l_openvino* ./openvino_toolkit &&\
|
||||
bash ./openvino_toolkit/install_dependencies/install_openvino_dependencies.sh
|
||||
|
||||
ENV OPENVINO_DIR=/root/workspace/openvino_toolkit
|
||||
ENV InferenceEngine_DIR=$OPENVINO_DIR/runtime/cmake
|
||||
|
|
|
@ -1,15 +1,38 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
def update_squeeze_unsqueeze_opset13_pass(graph, params_dict, torch_out):
|
||||
"""Update Squeeze/Unsqueeze axes for opset13."""
|
||||
for node in graph.nodes():
|
||||
if node.kind() in ['onnx::Squeeze', 'onnx::Unsqueeze'] and \
|
||||
node.hasAttribute('axes'):
|
||||
axes = node['axes']
|
||||
axes_node = graph.create('onnx::Constant')
|
||||
axes_node.t_('value', torch.LongTensor(axes))
|
||||
node.removeAttribute('axes')
|
||||
node.addInput(axes_node.output())
|
||||
axes_node.insertBefore(node)
|
||||
return graph, params_dict, torch_out
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter('torch.onnx.utils._model_to_graph')
|
||||
def model_to_graph__custom_optimizer(*args, **kwargs):
|
||||
"""Rewriter of _model_to_graph, add custom passes."""
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs)
|
||||
|
||||
if hasattr(ctx, 'opset'):
|
||||
opset_version = ctx.opset
|
||||
else:
|
||||
from mmdeploy.utils import get_ir_config
|
||||
opset_version = get_ir_config(ctx.cfg).get('opset_version', 11)
|
||||
if opset_version >= 13:
|
||||
graph, params_dict, torch_out = update_squeeze_unsqueeze_opset13_pass(
|
||||
graph, params_dict, torch_out)
|
||||
custom_passes = getattr(ctx, 'onnx_custom_passes', None)
|
||||
|
||||
if custom_passes is not None:
|
||||
|
|
|
@ -18,6 +18,6 @@ def optimize_onnx(ctx, graph, params_dict, torch_out):
|
|||
logger.warning(
|
||||
'Can not optimize model, please build torchscipt extension.\n'
|
||||
'More details: '
|
||||
'https://github.com/open-mmlab/mmdeploy/tree/1.x/docs/en/experimental/onnx_optimizer.md' # noqa
|
||||
'https://github.com/open-mmlab/mmdeploy/tree/main/docs/en/experimental/onnx_optimizer.md' # noqa
|
||||
)
|
||||
return graph, params_dict, torch_out
|
||||
|
|
|
@ -48,8 +48,8 @@ class GridPriorsTRTOp(torch.autograd.Function):
|
|||
stride_w: int):
|
||||
"""Map ops to onnx symbolics."""
|
||||
# zero_h and zero_w is used to provide shape to GridPriorsTRT
|
||||
feat_h = g.op('Unsqueeze', feat_h, axes_i=[0])
|
||||
feat_w = g.op('Unsqueeze', feat_w, axes_i=[0])
|
||||
feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0])
|
||||
feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0])
|
||||
zero_h = g.op(
|
||||
'ConstantOfShape',
|
||||
feat_h,
|
||||
|
|
|
@ -92,14 +92,14 @@ class VoxelDetection(BaseTask):
|
|||
|
||||
def create_input(
|
||||
self,
|
||||
pcd: str,
|
||||
pcd: Union[str, Sequence[str]],
|
||||
input_shape: Sequence[int] = None,
|
||||
data_preprocessor: Optional[BaseDataPreprocessor] = None
|
||||
) -> Tuple[Dict, torch.Tensor]:
|
||||
"""Create input for detector.
|
||||
|
||||
Args:
|
||||
pcd (str): Input pcd file path.
|
||||
pcd (str, Sequence[str]): Input pcd file path.
|
||||
input_shape (Sequence[int], optional): model input shape.
|
||||
Defaults to None.
|
||||
data_preprocessor (Optional[BaseDataPreprocessor], optional):
|
||||
|
@ -115,7 +115,9 @@ class VoxelDetection(BaseTask):
|
|||
test_pipeline = Compose(test_pipeline)
|
||||
box_type_3d, box_mode_3d = \
|
||||
get_box_type(cfg.test_dataloader.dataset.box_type_3d)
|
||||
|
||||
# do not support batch inference
|
||||
if isinstance(pcd, (list, tuple)):
|
||||
pcd = pcd[0]
|
||||
data = []
|
||||
data_ = dict(
|
||||
lidar_points=dict(lidar_path=pcd),
|
||||
|
|
|
@ -122,7 +122,7 @@ torchscript:
|
|||
|
||||
models:
|
||||
- name: FCN
|
||||
metafile: configs/fcn/fcn.yml
|
||||
metafile: configs/fcn/metafile.yaml
|
||||
model_configs:
|
||||
- configs/fcn/fcn_r50-d8_4xb2-40k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -134,7 +134,7 @@ models:
|
|||
- *pipeline_openvino_dynamic_fp32
|
||||
|
||||
- name: PSPNet
|
||||
metafile: configs/pspnet/pspnet.yml
|
||||
metafile: configs/pspnet/metafile.yaml
|
||||
model_configs:
|
||||
- configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -146,7 +146,7 @@ models:
|
|||
- *pipeline_openvino_static_fp32
|
||||
|
||||
- name: deeplabv3
|
||||
metafile: configs/deeplabv3/deeplabv3.yml
|
||||
metafile: configs/deeplabv3/metafile.yaml
|
||||
model_configs:
|
||||
- configs/deeplabv3/deeplabv3_r50-d8_4xb2-40k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -158,7 +158,7 @@ models:
|
|||
- *pipeline_openvino_dynamic_fp32
|
||||
|
||||
- name: deeplabv3+
|
||||
metafile: configs/deeplabv3plus/deeplabv3plus.yml
|
||||
metafile: configs/deeplabv3plus/metafile.yaml
|
||||
model_configs:
|
||||
- configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-40k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -170,7 +170,7 @@ models:
|
|||
- *pipeline_openvino_dynamic_fp32
|
||||
|
||||
- name: Fast-SCNN
|
||||
metafile: configs/fastscnn/fastscnn.yml
|
||||
metafile: configs/fastscnn/metafile.yaml
|
||||
model_configs:
|
||||
- configs/fastscnn/fast_scnn_8xb4-160k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -181,7 +181,7 @@ models:
|
|||
- *pipeline_openvino_static_fp32
|
||||
|
||||
- name: UNet
|
||||
metafile: configs/unet/unet.yml
|
||||
metafile: configs/unet/metafile.yaml
|
||||
model_configs:
|
||||
- configs/unet/unet-s5-d16_fcn_4xb4-160k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -192,7 +192,7 @@ models:
|
|||
- *pipeline_pplnn_dynamic_fp32
|
||||
|
||||
- name: ANN
|
||||
metafile: configs/ann/ann.yml
|
||||
metafile: configs/ann/metafile.yaml
|
||||
model_configs:
|
||||
- configs/ann/ann_r50-d8_4xb2-40k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -201,7 +201,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: APCNet
|
||||
metafile: configs/apcnet/apcnet.yml
|
||||
metafile: configs/apcnet/metafile.yaml
|
||||
model_configs:
|
||||
- configs/apcnet/apcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -211,7 +211,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: BiSeNetV1
|
||||
metafile: configs/bisenetv1/bisenetv1.yml
|
||||
metafile: configs/bisenetv1/metafile.yaml
|
||||
model_configs:
|
||||
- configs/bisenetv1/bisenetv1_r18-d32_4xb4-160k_cityscapes-1024x1024.py
|
||||
pipelines:
|
||||
|
@ -222,7 +222,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: BiSeNetV2
|
||||
metafile: configs/bisenetv2/bisenetv2.yml
|
||||
metafile: configs/bisenetv2/metafile.yaml
|
||||
model_configs:
|
||||
- configs/bisenetv2/bisenetv2_fcn_4xb4-160k_cityscapes-1024x1024.py
|
||||
pipelines:
|
||||
|
@ -233,7 +233,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: CGNet
|
||||
metafile: configs/cgnet/cgnet.yml
|
||||
metafile: configs/cgnet/metafile.yaml
|
||||
model_configs:
|
||||
- configs/cgnet/cgnet_fcn_4xb8-60k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -244,7 +244,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: EMANet
|
||||
metafile: configs/emanet/emanet.yml
|
||||
metafile: configs/emanet/metafile.yaml
|
||||
model_configs:
|
||||
- configs/emanet/emanet_r50-d8_4xb2-80k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -254,7 +254,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: EncNet
|
||||
metafile: configs/encnet/encnet.yml
|
||||
metafile: configs/encnet/metafile.yaml
|
||||
model_configs:
|
||||
- configs/encnet/encnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -264,7 +264,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: ERFNet
|
||||
metafile: configs/erfnet/erfnet.yml
|
||||
metafile: configs/erfnet/metafile.yaml
|
||||
model_configs:
|
||||
- configs/erfnet/erfnet_fcn_4xb4-160k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -275,7 +275,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: FastFCN
|
||||
metafile: configs/fastfcn/fastfcn.yml
|
||||
metafile: configs/fastfcn/metafile.yaml
|
||||
model_configs:
|
||||
- configs/fastfcn/fastfcn_r50-d32_jpu_aspp_4xb2-80k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -286,7 +286,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: GCNet
|
||||
metafile: configs/gcnet/gcnet.yml
|
||||
metafile: configs/gcnet/metafile.yaml
|
||||
model_configs:
|
||||
- configs/gcnet/gcnet_r50-d8_4xb2-40k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -295,7 +295,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: ICNet
|
||||
metafile: configs/icnet/icnet.yml
|
||||
metafile: configs/icnet/metafile.yaml
|
||||
model_configs:
|
||||
- configs/icnet/icnet_r18-d8_4xb2-80k_cityscapes-832x832.py
|
||||
pipelines:
|
||||
|
@ -305,7 +305,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: ISANet
|
||||
metafile: configs/isanet/isanet.yml
|
||||
metafile: configs/isanet/metafile.yaml
|
||||
model_configs:
|
||||
- configs/isanet/isanet_r50-d8_4xb2-40k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -314,7 +314,7 @@ models:
|
|||
- *pipeline_openvino_static_fp32_512x512
|
||||
|
||||
- name: OCRNet
|
||||
metafile: configs/ocrnet/ocrnet.yml
|
||||
metafile: configs/ocrnet/metafile.yaml
|
||||
model_configs:
|
||||
- configs/ocrnet/ocrnet_hr18s_4xb2-40k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -325,7 +325,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: PointRend
|
||||
metafile: configs/point_rend/point_rend.yml
|
||||
metafile: configs/point_rend/metafile.yaml
|
||||
model_configs:
|
||||
- configs/point_rend/pointrend_r50_4xb2-80k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -334,7 +334,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: Semantic FPN
|
||||
metafile: configs/sem_fpn/sem_fpn.yml
|
||||
metafile: configs/sem_fpn/metafile.yaml
|
||||
model_configs:
|
||||
- configs/sem_fpn/fpn_r50_4xb2-80k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
|
@ -345,7 +345,7 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: STDC
|
||||
metafile: configs/stdc/stdc.yml
|
||||
metafile: configs/stdc/metafile.yaml
|
||||
model_configs:
|
||||
- configs/stdc/stdc1_in1k-pre_4xb12-80k_cityscapes-512x1024.py
|
||||
- configs/stdc/stdc2_in1k-pre_4xb12-80k_cityscapes-512x1024.py
|
||||
|
@ -357,14 +357,14 @@ models:
|
|||
- *pipeline_ts_fp32
|
||||
|
||||
- name: UPerNet
|
||||
metafile: configs/upernet/upernet.yml
|
||||
metafile: configs/upernet/metafile.yaml
|
||||
model_configs:
|
||||
- configs/upernet/upernet_r50_4xb2-40k_cityscapes-512x1024.py
|
||||
pipelines:
|
||||
- *pipeline_ort_static_fp32
|
||||
- *pipeline_trt_static_fp16
|
||||
- name: Segmenter
|
||||
metafile: configs/segmenter/segmenter.yml
|
||||
metafile: configs/segmenter/metafile.yaml
|
||||
model_configs:
|
||||
- configs/segmenter/segmenter_vit-s_fcn_8xb1-160k_ade20k-512x512.py
|
||||
pipelines:
|
||||
|
|
|
@ -302,6 +302,9 @@ def get_pytorch_result(model_name: str, meta_info: dict, checkpoint_path: Path,
|
|||
# get metric
|
||||
model_info = meta_info[model_config_name]
|
||||
metafile_metric_info = model_info['Results']
|
||||
# deal with mmseg case
|
||||
if not isinstance(metafile_metric_info, (list, tuple)):
|
||||
metafile_metric_info = [metafile_metric_info]
|
||||
pytorch_metric = dict()
|
||||
using_dataset = set()
|
||||
using_task = set()
|
||||
|
|
Loading…
Reference in New Issue