[Fix] fix visualization for partition (#1424)

* init

* lint

* pass output_names outside

* docstring & type hint
pull/1459/head
AllentDan 2022-11-29 11:40:00 +08:00 committed by GitHub
parent b521e7da03
commit 047ab67c78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 2 deletions

View File

@ -16,7 +16,7 @@ from mmdeploy.codebase.base import BaseBackendModel
from mmdeploy.codebase.mmdet.core.post_processing import multiclass_nms
from mmdeploy.codebase.mmdet.deploy import get_post_processing_params
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
get_partition_config, load_config)
get_ir_config, get_partition_config, load_config)
def __build_backend_model(partition_name: str, backend: Backend,
@ -678,6 +678,30 @@ class RKNNModel(End2EndModel):
model_cfg = load_config(model_cfg)[0]
self.model_cfg = model_cfg
def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
device: str):
"""Initialize backend wrapper.
Args:
backend (Backend): The backend enum, specifying backend type.
backend_files (Sequence[str]): Paths rknn model files.
device (str): A string specifying device type.
"""
output_names = None
if self.deploy_cfg is not None:
ir_config = get_ir_config(self.deploy_cfg)
output_names = ir_config.get('output_names', None)
if get_partition_config(self.deploy_cfg) is not None:
output_names = get_partition_config(
self.deploy_cfg)['partition_cfg'][0]['output_names']
self.wrapper = BaseBackendModel._build_wrapper(
backend,
backend_files,
device,
output_names=output_names,
deploy_cfg=self.deploy_cfg)
def _get_bboxes(self, outputs, img_metas):
from mmdet.models import build_head
head_cfg = self.model_cfg._cfg_dict.model.bbox_head

View File

@ -457,7 +457,9 @@ def test_build_object_detection_model(partition_type):
type='mmdet', post_processing=post_processing)))
if partition_type:
deploy_cfg.partition_config = dict(
apply_marks=True, type=partition_type)
apply_marks=True,
type=partition_type,
partition_cfg=[dict(output_names=[])])
from mmdeploy.backend.onnxruntime import ORTWrapper
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})