remove imports (#1207)

* remove imports

* update doc

* detailed docstring

* rephrase
pull/1239/head
AllentDan 2022-10-24 10:45:52 +08:00 committed by GitHub
parent 4c872a41c3
commit 114b0b8238
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 33 additions and 5 deletions

View File

@ -21,6 +21,8 @@ Please refer to [install.md](https://mmocr.readthedocs.io/en/latest/install.html
Note that ncnn, pplnn, and OpenVINO only support the configs of DBNet18 for DBNet.
For CRNN models with TensorRT-int8 backend, we recommend TensorRT 7.2.3.4 and CUDA 10.2.
For the PANet with the [checkpoint](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth) pretrained on ICDAR dataset, if you want to convert the model to TensorRT with 16 bits float point, please try the following script.
```python

View File

@ -21,6 +21,8 @@ mmocr 是一个基于 PyTorch 和 mmdetection 的开源工具箱,用于文本
请注意ncnn、pplnn 和 OpenVINO 仅支持 DBNet 的 DBNet18 配置。
CRNN 模型的 TensorRT int8量化只在 TensorRT 7.2.3.4 和 CUDA10.2下测试可用。
对于在 ICDAR 数据集上预训 [checkpoint](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth) 的 PANet如果要将模型转为具有 fp16 TensorRT请尝试以下脚本。
```python

View File

@ -8,7 +8,8 @@ import onnx.utils
from mmdeploy.apis.core import PIPELINE_MANAGER
from mmdeploy.core.optimizers import (attribute_to_dict, create_extractor,
get_new_name, parse_extractor_io_string,
remove_identity, rename_value)
remove_identity, remove_imports,
rename_value)
from mmdeploy.utils import get_root_logger
@ -198,6 +199,9 @@ def extract_partition(model: Union[str, onnx.ModelProto],
dim.dim_value = 0
dim.dim_param = f'dim_{idx}'
# remove mmdeploy domain if useless
remove_imports(extracted_model)
# save extract_model if save_file is given
if save_file is not None:
onnx.save(extracted_model, save_file)

View File

@ -42,8 +42,7 @@ def process_model_config(model_cfg: mmcv.Config,
transforms = cfg.data.test.pipeline[1]['transforms']
for trans in transforms:
trans_type = trans['type']
if trans_type == 'Resize' and len(
input_shape) != 1 and input_shape[0] != input_shape[1]:
if trans_type == 'Resize' and len(input_shape) != 1:
trans['keep_ratio'] = False
elif trans_type == 'Pad':
if 'size_divisor' in trans:

View File

@ -2,10 +2,10 @@
from .extractor import create_extractor, parse_extractor_io_string
from .function_marker import mark, reset_mark_function_count
from .optimize import (attribute_to_dict, get_new_name, remove_identity,
rename_value)
remove_imports, rename_value)
__all__ = [
'mark', 'reset_mark_function_count', 'create_extractor',
'parse_extractor_io_string', 'remove_identity', 'attribute_to_dict',
'rename_value', 'get_new_name'
'rename_value', 'get_new_name', 'remove_imports'
]

View File

@ -206,3 +206,24 @@ def remove_identity(model: onnx.ModelProto):
pass
remove_nodes(model, is_identity)
def remove_imports(model: onnx.ModelProto):
"""Remove useless imports from an ONNX model.
The domain like `mmdeploy` might influence model conversion for
some backends.
Args:
model (onnx.ModelProto): Input onnx model.
"""
logger = get_root_logger()
dst_domain = ['']
for node in model.graph.node:
if hasattr(node, 'module') and (node.module not in dst_domain):
dst_domain.append(node.module)
src_domains = [oi.domain for oi in model.opset_import]
for i, src_domain in enumerate(src_domains):
if src_domain not in dst_domain:
logger.info(f'remove opset_import {src_domain}')
model.opset_import.pop(i)