mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
remove imports (#1207)
* remove imports * update doc * detailed docstring * rephrase
This commit is contained in:
parent
4c872a41c3
commit
114b0b8238
@ -21,6 +21,8 @@ Please refer to [install.md](https://mmocr.readthedocs.io/en/latest/install.html
|
|||||||
|
|
||||||
Note that ncnn, pplnn, and OpenVINO only support the configs of DBNet18 for DBNet.
|
Note that ncnn, pplnn, and OpenVINO only support the configs of DBNet18 for DBNet.
|
||||||
|
|
||||||
|
For CRNN models with TensorRT-int8 backend, we recommend TensorRT 7.2.3.4 and CUDA 10.2.
|
||||||
|
|
||||||
For the PANet with the [checkpoint](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth) pretrained on ICDAR dataset, if you want to convert the model to TensorRT with 16 bits float point, please try the following script.
|
For the PANet with the [checkpoint](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth) pretrained on ICDAR dataset, if you want to convert the model to TensorRT with 16 bits float point, please try the following script.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ -21,6 +21,8 @@ mmocr 是一个基于 PyTorch 和 mmdetection 的开源工具箱,用于文本
|
|||||||
|
|
||||||
请注意,ncnn、pplnn 和 OpenVINO 仅支持 DBNet 的 DBNet18 配置。
|
请注意,ncnn、pplnn 和 OpenVINO 仅支持 DBNet 的 DBNet18 配置。
|
||||||
|
|
||||||
|
CRNN 模型的 TensorRT int8量化只在 TensorRT 7.2.3.4 和 CUDA10.2下测试可用。
|
||||||
|
|
||||||
对于在 ICDAR 数据集上预训 [checkpoint](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth) 的 PANet,如果要将模型转为具有 fp16 TensorRT,请尝试以下脚本。
|
对于在 ICDAR 数据集上预训 [checkpoint](https://download.openmmlab.com/mmocr/textdet/panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth) 的 PANet,如果要将模型转为具有 fp16 TensorRT,请尝试以下脚本。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ -8,7 +8,8 @@ import onnx.utils
|
|||||||
from mmdeploy.apis.core import PIPELINE_MANAGER
|
from mmdeploy.apis.core import PIPELINE_MANAGER
|
||||||
from mmdeploy.core.optimizers import (attribute_to_dict, create_extractor,
|
from mmdeploy.core.optimizers import (attribute_to_dict, create_extractor,
|
||||||
get_new_name, parse_extractor_io_string,
|
get_new_name, parse_extractor_io_string,
|
||||||
remove_identity, rename_value)
|
remove_identity, remove_imports,
|
||||||
|
rename_value)
|
||||||
from mmdeploy.utils import get_root_logger
|
from mmdeploy.utils import get_root_logger
|
||||||
|
|
||||||
|
|
||||||
@ -198,6 +199,9 @@ def extract_partition(model: Union[str, onnx.ModelProto],
|
|||||||
dim.dim_value = 0
|
dim.dim_value = 0
|
||||||
dim.dim_param = f'dim_{idx}'
|
dim.dim_param = f'dim_{idx}'
|
||||||
|
|
||||||
|
# remove mmdeploy domain if useless
|
||||||
|
remove_imports(extracted_model)
|
||||||
|
|
||||||
# save extract_model if save_file is given
|
# save extract_model if save_file is given
|
||||||
if save_file is not None:
|
if save_file is not None:
|
||||||
onnx.save(extracted_model, save_file)
|
onnx.save(extracted_model, save_file)
|
||||||
|
@ -42,8 +42,7 @@ def process_model_config(model_cfg: mmcv.Config,
|
|||||||
transforms = cfg.data.test.pipeline[1]['transforms']
|
transforms = cfg.data.test.pipeline[1]['transforms']
|
||||||
for trans in transforms:
|
for trans in transforms:
|
||||||
trans_type = trans['type']
|
trans_type = trans['type']
|
||||||
if trans_type == 'Resize' and len(
|
if trans_type == 'Resize' and len(input_shape) != 1:
|
||||||
input_shape) != 1 and input_shape[0] != input_shape[1]:
|
|
||||||
trans['keep_ratio'] = False
|
trans['keep_ratio'] = False
|
||||||
elif trans_type == 'Pad':
|
elif trans_type == 'Pad':
|
||||||
if 'size_divisor' in trans:
|
if 'size_divisor' in trans:
|
||||||
|
@ -2,10 +2,10 @@
|
|||||||
from .extractor import create_extractor, parse_extractor_io_string
|
from .extractor import create_extractor, parse_extractor_io_string
|
||||||
from .function_marker import mark, reset_mark_function_count
|
from .function_marker import mark, reset_mark_function_count
|
||||||
from .optimize import (attribute_to_dict, get_new_name, remove_identity,
|
from .optimize import (attribute_to_dict, get_new_name, remove_identity,
|
||||||
rename_value)
|
remove_imports, rename_value)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'mark', 'reset_mark_function_count', 'create_extractor',
|
'mark', 'reset_mark_function_count', 'create_extractor',
|
||||||
'parse_extractor_io_string', 'remove_identity', 'attribute_to_dict',
|
'parse_extractor_io_string', 'remove_identity', 'attribute_to_dict',
|
||||||
'rename_value', 'get_new_name'
|
'rename_value', 'get_new_name', 'remove_imports'
|
||||||
]
|
]
|
||||||
|
@ -206,3 +206,24 @@ def remove_identity(model: onnx.ModelProto):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
remove_nodes(model, is_identity)
|
remove_nodes(model, is_identity)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_imports(model: onnx.ModelProto):
|
||||||
|
"""Remove useless imports from an ONNX model.
|
||||||
|
|
||||||
|
The domain like `mmdeploy` might influence model conversion for
|
||||||
|
some backends.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (onnx.ModelProto): Input onnx model.
|
||||||
|
"""
|
||||||
|
logger = get_root_logger()
|
||||||
|
dst_domain = ['']
|
||||||
|
for node in model.graph.node:
|
||||||
|
if hasattr(node, 'module') and (node.module not in dst_domain):
|
||||||
|
dst_domain.append(node.module)
|
||||||
|
src_domains = [oi.domain for oi in model.opset_import]
|
||||||
|
for i, src_domain in enumerate(src_domains):
|
||||||
|
if src_domain not in dst_domain:
|
||||||
|
logger.info(f'remove opset_import {src_domain}')
|
||||||
|
model.opset_import.pop(i)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user