remove imports (#1207)

* remove imports

* update doc

* detailed docstring

* rephrase
This commit is contained in:
AllentDan 2022-10-24 10:45:52 +08:00 committed by GitHub
parent 4c872a41c3
commit 114b0b8238
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 33 additions and 5 deletions

View File

@ -21,6 +21,8 @@ Please refer to [install.md](https://mmocr.readthedocs.io/en/latest/install.html
Note that ncnn, pplnn, and OpenVINO only support the configs of DBNet18 for DBNet. Note that ncnn, pplnn, and OpenVINO only support the configs of DBNet18 for DBNet.
For CRNN models with TensorRT-int8 backend, we recommend TensorRT 7.2.3.4 and CUDA 10.2.
For the PANet with the [checkpoint](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth) pretrained on ICDAR dataset, if you want to convert the model to TensorRT with 16 bits float point, please try the following script. For the PANet with the [checkpoint](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth) pretrained on ICDAR dataset, if you want to convert the model to TensorRT with 16 bits float point, please try the following script.
```python ```python

View File

@ -21,6 +21,8 @@ mmocr 是一个基于 PyTorch 和 mmdetection 的开源工具箱,用于文本
请注意ncnn、pplnn 和 OpenVINO 仅支持 DBNet 的 DBNet18 配置。 请注意ncnn、pplnn 和 OpenVINO 仅支持 DBNet 的 DBNet18 配置。
CRNN 模型的 TensorRT int8量化只在 TensorRT 7.2.3.4 和 CUDA10.2下测试可用。
对于在 ICDAR 数据集上预训 [checkpoint](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth) 的 PANet如果要将模型转为具有 fp16 TensorRT请尝试以下脚本。 对于在 ICDAR 数据集上预训 [checkpoint](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth) 的 PANet如果要将模型转为具有 fp16 TensorRT请尝试以下脚本。
```python ```python

View File

@ -8,7 +8,8 @@ import onnx.utils
from mmdeploy.apis.core import PIPELINE_MANAGER from mmdeploy.apis.core import PIPELINE_MANAGER
from mmdeploy.core.optimizers import (attribute_to_dict, create_extractor, from mmdeploy.core.optimizers import (attribute_to_dict, create_extractor,
get_new_name, parse_extractor_io_string, get_new_name, parse_extractor_io_string,
remove_identity, rename_value) remove_identity, remove_imports,
rename_value)
from mmdeploy.utils import get_root_logger from mmdeploy.utils import get_root_logger
@ -198,6 +199,9 @@ def extract_partition(model: Union[str, onnx.ModelProto],
dim.dim_value = 0 dim.dim_value = 0
dim.dim_param = f'dim_{idx}' dim.dim_param = f'dim_{idx}'
# remove mmdeploy domain if useless
remove_imports(extracted_model)
# save extract_model if save_file is given # save extract_model if save_file is given
if save_file is not None: if save_file is not None:
onnx.save(extracted_model, save_file) onnx.save(extracted_model, save_file)

View File

@ -42,8 +42,7 @@ def process_model_config(model_cfg: mmcv.Config,
transforms = cfg.data.test.pipeline[1]['transforms'] transforms = cfg.data.test.pipeline[1]['transforms']
for trans in transforms: for trans in transforms:
trans_type = trans['type'] trans_type = trans['type']
if trans_type == 'Resize' and len( if trans_type == 'Resize' and len(input_shape) != 1:
input_shape) != 1 and input_shape[0] != input_shape[1]:
trans['keep_ratio'] = False trans['keep_ratio'] = False
elif trans_type == 'Pad': elif trans_type == 'Pad':
if 'size_divisor' in trans: if 'size_divisor' in trans:

View File

@ -2,10 +2,10 @@
from .extractor import create_extractor, parse_extractor_io_string from .extractor import create_extractor, parse_extractor_io_string
from .function_marker import mark, reset_mark_function_count from .function_marker import mark, reset_mark_function_count
from .optimize import (attribute_to_dict, get_new_name, remove_identity, from .optimize import (attribute_to_dict, get_new_name, remove_identity,
rename_value) remove_imports, rename_value)
__all__ = [ __all__ = [
'mark', 'reset_mark_function_count', 'create_extractor', 'mark', 'reset_mark_function_count', 'create_extractor',
'parse_extractor_io_string', 'remove_identity', 'attribute_to_dict', 'parse_extractor_io_string', 'remove_identity', 'attribute_to_dict',
'rename_value', 'get_new_name' 'rename_value', 'get_new_name', 'remove_imports'
] ]

View File

@ -206,3 +206,24 @@ def remove_identity(model: onnx.ModelProto):
pass pass
remove_nodes(model, is_identity) remove_nodes(model, is_identity)
def remove_imports(model: onnx.ModelProto):
"""Remove useless imports from an ONNX model.
The domain like `mmdeploy` might influence model conversion for
some backends.
Args:
model (onnx.ModelProto): Input onnx model.
"""
logger = get_root_logger()
dst_domain = ['']
for node in model.graph.node:
if hasattr(node, 'module') and (node.module not in dst_domain):
dst_domain.append(node.module)
src_domains = [oi.domain for oi in model.opset_import]
for i, src_domain in enumerate(src_domains):
if src_domain not in dst_domain:
logger.info(f'remove opset_import {src_domain}')
model.opset_import.pop(i)