Fix some bugs on dev1.x (#1390)

* fix onnx export unused param

* add cfgoptions in reg test
This commit is contained in:
RunningLeon 2022-11-18 18:35:00 +08:00 committed by GitHub
parent a59b17259d
commit a10b9e964b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 24 additions and 7 deletions

View File

@ -25,7 +25,6 @@ def export(model: torch.nn.Module,
dynamic_axes: Optional[Dict] = None, dynamic_axes: Optional[Dict] = None,
verbose: bool = False, verbose: bool = False,
keep_initializers_as_inputs: Optional[bool] = None, keep_initializers_as_inputs: Optional[bool] = None,
patch_metas: Dict = {},
optimize: bool = False): optimize: bool = False):
"""Export a PyTorch model into ONNX format. This is a wrap of """Export a PyTorch model into ONNX format. This is a wrap of
`torch.onnx.export` with some enhancement. `torch.onnx.export` with some enhancement.
@ -67,7 +66,6 @@ def export(model: torch.nn.Module,
verbose (bool): Enable verbose model on `torch.onnx.export`. verbose (bool): Enable verbose model on `torch.onnx.export`.
keep_initializers_as_inputs (bool): Whether we should add inputs for keep_initializers_as_inputs (bool): Whether we should add inputs for
each initializer. each initializer.
patch_meta (Dict): The information used to patch the model.
optimize (bool): Perform optimize on model. optimize (bool): Perform optimize on model.
""" """
output_path = output_path_prefix + '.onnx' output_path = output_path_prefix + '.onnx'
@ -117,8 +115,9 @@ def export(model: torch.nn.Module,
assert isinstance( assert isinstance(
input_metas, dict input_metas, dict
), f'Expect input_metas type is dict, get {type(input_metas)}.' ), f'Expect input_metas type is dict, get {type(input_metas)}.'
model_forward = model.forward model_forward = patched_model.forward
model.forward = partial(model.forward, **input_metas) patched_model.forward = partial(patched_model.forward,
**input_metas)
torch.onnx.export( torch.onnx.export(
patched_model, patched_model,
@ -133,4 +132,4 @@ def export(model: torch.nn.Module,
verbose=verbose) verbose=verbose)
if input_metas is not None: if input_metas is not None:
model.forward = model_forward patched_model.forward = model_forward

View File

@ -67,7 +67,6 @@ def torch2onnx(img: Any,
if isinstance(model_inputs, list) and len(model_inputs) == 1: if isinstance(model_inputs, list) and len(model_inputs) == 1:
model_inputs = model_inputs[0] model_inputs = model_inputs[0]
data_samples = data['data_samples'] data_samples = data['data_samples']
patch_metas = {'data_samples': data_samples}
input_metas = {'data_samples': data_samples, 'mode': 'predict'} input_metas = {'data_samples': data_samples, 'mode': 'predict'}
# export to onnx # export to onnx
@ -107,5 +106,4 @@ def torch2onnx(img: Any,
dynamic_axes=dynamic_axes, dynamic_axes=dynamic_axes,
verbose=verbose, verbose=verbose,
keep_initializers_as_inputs=keep_initializers_as_inputs, keep_initializers_as_inputs=keep_initializers_as_inputs,
patch_metas=patch_metas,
optimize=optimize) optimize=optimize)

View File

@ -2,6 +2,7 @@ aenum
grpcio grpcio
h5py h5py
matplotlib matplotlib
mmengine
multiprocess multiprocess
numpy numpy
onnx>=1.8.0 onnx>=1.8.0

View File

@ -328,3 +328,14 @@ models:
pipelines: pipelines:
- *pipeline_seg_ort_dynamic_fp32 - *pipeline_seg_ort_dynamic_fp32
- *pipeline_seg_trt_dynamic_fp32 - *pipeline_seg_trt_dynamic_fp32
- name: RTMDet
metafile: configs/rtmdet/metafile.yml
model_configs:
- configs/rtmdet/rtmdet_s_8xb32-300e_coco.py
pipelines:
- *pipeline_ort_dynamic_fp32
- deploy_config: configs/mmdet/detection/detection_tensorrt_static-640x640.py
convert_image: *convert_image
backend_test: *default_backend_test
sdk_config: *sdk_dynamic

View File

@ -487,6 +487,14 @@ def get_backend_fps_metric(deploy_cfg_path: str, model_cfg_path: Path,
] ]
codebase_name = get_codebase(str(deploy_cfg_path)).value codebase_name = get_codebase(str(deploy_cfg_path)).value
# to stop Dataloader OOM in docker CI
if codebase_name not in ['medit', 'mmocr']:
cfg_options = 'test_dataloader.num_workers=0 ' \
'test_dataloader.persistent_workers=False ' \
'val_dataloader.num_workers=0 ' \
'val_dataloader.persistent_workers=False '
cmd_lines.append(f'--cfg-options {cfg_options}')
# Test backend # Test backend
return_code = run_cmd(cmd_lines, log_path) return_code = run_cmd(cmd_lines, log_path)
fps, backend_metric, test_pass = get_fps_metric(return_code, fps, backend_metric, test_pass = get_fps_metric(return_code,