[Fix] fix visualization for partition (#1424)
* init * lint * pass output_names outside * docstring & type hintpull/1459/head
parent
b521e7da03
commit
047ab67c78
|
@ -16,7 +16,7 @@ from mmdeploy.codebase.base import BaseBackendModel
|
|||
from mmdeploy.codebase.mmdet.core.post_processing import multiclass_nms
|
||||
from mmdeploy.codebase.mmdet.deploy import get_post_processing_params
|
||||
from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
|
||||
get_partition_config, load_config)
|
||||
get_ir_config, get_partition_config, load_config)
|
||||
|
||||
|
||||
def __build_backend_model(partition_name: str, backend: Backend,
|
||||
|
@ -678,6 +678,30 @@ class RKNNModel(End2EndModel):
|
|||
model_cfg = load_config(model_cfg)[0]
|
||||
self.model_cfg = model_cfg
|
||||
|
||||
def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
|
||||
device: str):
|
||||
"""Initialize backend wrapper.
|
||||
|
||||
Args:
|
||||
backend (Backend): The backend enum, specifying backend type.
|
||||
backend_files (Sequence[str]): Paths rknn model files.
|
||||
device (str): A string specifying device type.
|
||||
"""
|
||||
output_names = None
|
||||
if self.deploy_cfg is not None:
|
||||
ir_config = get_ir_config(self.deploy_cfg)
|
||||
output_names = ir_config.get('output_names', None)
|
||||
if get_partition_config(self.deploy_cfg) is not None:
|
||||
output_names = get_partition_config(
|
||||
self.deploy_cfg)['partition_cfg'][0]['output_names']
|
||||
|
||||
self.wrapper = BaseBackendModel._build_wrapper(
|
||||
backend,
|
||||
backend_files,
|
||||
device,
|
||||
output_names=output_names,
|
||||
deploy_cfg=self.deploy_cfg)
|
||||
|
||||
def _get_bboxes(self, outputs, img_metas):
|
||||
from mmdet.models import build_head
|
||||
head_cfg = self.model_cfg._cfg_dict.model.bbox_head
|
||||
|
|
|
@ -457,7 +457,9 @@ def test_build_object_detection_model(partition_type):
|
|||
type='mmdet', post_processing=post_processing)))
|
||||
if partition_type:
|
||||
deploy_cfg.partition_config = dict(
|
||||
apply_marks=True, type=partition_type)
|
||||
apply_marks=True,
|
||||
type=partition_type,
|
||||
partition_cfg=[dict(output_names=[])])
|
||||
|
||||
from mmdeploy.backend.onnxruntime import ORTWrapper
|
||||
ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
|
||||
|
|
Loading…
Reference in New Issue