47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import mmengine
|
|
import torch
|
|
from mmengine.structures import PixelData
|
|
|
|
from mmdeploy.apis import build_task_processor
|
|
from mmdeploy.utils import load_config
|
|
|
|
|
|
def generate_datasample(h, w):
    """Build a random ``SegDataSample`` whose spatial size is ``h`` x ``w``.

    The sample gets ``img_shape``, ``ori_shape`` and ``pad_shape`` metainfo
    all equal to ``(h, w)``. It also gets two random maps, each a
    ``(1, h, w)`` integer tensor with values in ``{0, 1}``:

    * ``pred_sem_seg``: a ``PixelData`` holding the prediction map.
    * ``gt_sem_seg``: a ``PixelData`` holding the ground-truth map. It also
      carries the same metainfo dict.

    Args:
        h (int): Height of the fake image and segmentation maps.
        w (int): Width of the fake image and segmentation maps.

    Returns:
        SegDataSample: The populated data sample.
    """
    # Local import: mmseg is only needed when this helper is actually called
    # (presumably to keep module import cheap/optional -- confirm).
    from mmseg.structures import SegDataSample

    hw = (h, w)
    meta = dict(img_shape=hw, ori_shape=hw, pad_shape=hw)

    sample = SegDataSample()
    sample.set_metainfo(meta)

    # Prediction is drawn before ground truth; keep this order so the
    # global torch RNG stream produces the same tensors as before.
    pred_map = torch.randint(0, 2, (1, h, w))
    gt_map = torch.randint(0, 2, (1, h, w))

    pred_pixels = PixelData(data=pred_map)
    gt_pixels = PixelData(data=gt_map, metainfo=meta)

    sample.set_data(dict(pred_sem_seg=pred_pixels))
    sample.set_data(dict(gt_sem_seg=gt_pixels))
    return sample
|
|
|
|
|
|
def generate_mmseg_deploy_config(backend='onnxruntime'):
    """Create a minimal mmseg deploy config targeting ``backend``.

    Args:
        backend (str): Value stored as ``backend_config.type``.
            Defaults to ``'onnxruntime'``.

    Returns:
        mmengine.Config: A config with three sections:

        * ``backend_config``: the requested backend type.
        * ``codebase_config``: the mmseg ``Segmentation`` task.
        * ``onnx_config``: ONNX export settings. These are opset 11, params
          exported, initializers not exposed as inputs, no fixed input shape,
          and a single ``input`` / ``output`` tensor name.
    """
    onnx_section = dict(
        type='onnx',
        export_params=True,
        keep_initializers_as_inputs=False,
        opset_version=11,
        input_shape=None,
        input_names=['input'],
        output_names=['output'],
    )
    backend_section = dict(type=backend)
    codebase_section = dict(type='mmseg', task='Segmentation')

    raw_cfg = dict(
        backend_config=backend_section,
        codebase_config=codebase_section,
        onnx_config=onnx_section,
    )
    return mmengine.Config(raw_cfg)
|
|
|
|
|
|
def generate_mmseg_task_processor(model_cfg=None, deploy_cfg=None):
    """Build an mmseg task processor on the ``'cpu'`` device.

    Args:
        model_cfg: Model config in any form ``load_config`` accepts. ``None``
            selects ``'tests/test_codebase/test_mmseg/data/model.py'``. That
            path is relative, so it presumably assumes the tests run from the
            repository root.
        deploy_cfg: Deploy config in any form ``load_config`` accepts.
            ``None`` selects ``generate_mmseg_deploy_config()``, which uses
            the onnxruntime backend.

    Returns:
        The object returned by ``build_task_processor`` for the loaded
        configs on device ``'cpu'``.
    """
    model_src = ('tests/test_codebase/test_mmseg/data/model.py'
                 if model_cfg is None else model_cfg)
    deploy_src = (generate_mmseg_deploy_config()
                  if deploy_cfg is None else deploy_cfg)

    loaded_model_cfg, loaded_deploy_cfg = load_config(model_src, deploy_src)
    return build_task_processor(loaded_model_cfg, loaded_deploy_cfg, 'cpu')
|