mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* [WIP] Refactor v2.0 (#163) * Refactor backend wrapper * Refactor mmdet.inference * Fix * merge * refactor utils * Use deployer and deploy_model to manage pipeline * Resolve comments * Add a real inference api function * rename wrappers * Set execute to private method * Rename deployer deploy_model * Refactor task * remove type hint * lint * Resolve comments * resolve comments * lint * docstring * [Fix]: Fix bugs in details in refactor branch (#192) * [WIP] Refactor v2.0 (#163) * Refactor backend wrapper * Refactor mmdet.inference * Fix * merge * refactor utils * Use deployer and deploy_model to manage pipeline * Resolve comments * Add a real inference api function * rename wrappers * Set execute to private method * Rename deployer deploy_model * Refactor task * remove type hint * lint * Resolve comments * resolve comments * lint * docstring * Fix errors * lint * resolve comments * fix bugs * conflict * lint and typo * Resolve comment * refactor mmseg (#201) * support mmseg * fix docstring * fix docstring * [Refactor]: Get the count of backend files (#202) * Fix backend files * resolve comments * lint * Fix ncnn * [Refactor]: Refactor folders of mmdet (#200) * Move folders * lint * test object detection model * lint * reset changes * fix openvino * resolve comments * __init__.py * Fix path * [Refactor]: move mmseg (#206) * [Refactor]: Refactor mmedit (#205) * feature mmedit * edit2.0 * edit * refactor mmedit * fix __init__.py * fix __init__ * fix formai * fix comment * fix comment * Fix wrong func_name of ConvFCBBoxHead (#209) * [Refactor]: Refactor mmdet unit test (#207) * Move folders * lint * test object detection model * lint * WIP * remove print * finish unit test * Fix tests * resolve comments * Add mask test * lint * resolve comments * Refine cfg file * Move files * add files * Fix path * [Unittest]: Refine the unit tests in mmdet #214 * [Refactor] refactor mmocr to mmdeploy/codebase (#213) * refactor mmocr to mmdeploy/codebase * fix docstring of 
show_result * fix docstring of visualize * refine docstring * replace print with logging * refince codes * resolve comments * resolve comments * [Refactor]: mmseg tests (#210) * refactor mmseg tests * rename test_codebase * update * add model.py * fix * [Refactor] Refactor mmcls and the package (#217) * refactor mmcls * fix yapf * fix isort * refactor-mmcls-package * fix print to logging * fix docstrings according to others comments * fix comments * fix comments * fix allentdans comment in pr215 * remove mmocr init * [Refactor] Refactor mmedit tests (#212) * feature mmedit * edit2.0 * edit * refactor mmedit * fix __init__.py * fix __init__ * fix formai * fix comment * fix comment * buff * edit test and code refactor * refactor dir * refactor tests/mmedit * fix docstring * add test coverage * fix lint * fix comment * fix comment * Update typehint (#216) * update type hint * update docstring * update * remove file * fix ppl * Refine get_predefined_partition_cfg * fix tensorrt version > 8 * move parse_cuda_device_id to device.py * Fix cascade * onnx2ncnn docstring Co-authored-by: Yifan Zhou <singlezombie@163.com> Co-authored-by: RunningLeon <maningsheng@sensetime.com> Co-authored-by: VVsssssk <88368822+VVsssssk@users.noreply.github.com> Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com> Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
230 lines
7.4 KiB
Python
230 lines
7.4 KiB
Python
import mmcv
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmcv import ConfigDict
|
|
from mmseg.models import BACKBONES, HEADS
|
|
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
|
|
|
from mmdeploy.utils.test import (WrapModel, get_model_outputs,
|
|
get_rewrite_outputs)
|
|
|
|
|
|
@BACKBONES.register_module()
class ExampleBackbone(nn.Module):
    """Tiny backbone used by the tests in this file.

    It owns a single 3x3 convolution that maps 3 input channels to 3 output
    channels and is registered in the ``BACKBONES`` registry so it can be
    referenced by name from a config.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)

    def init_weights(self, pretrained=None):
        """Intentionally do nothing; ``pretrained`` is ignored."""
        return None

    def forward(self, x):
        """Run ``x`` through the convolution.

        Returns:
            list: A one-element list holding the convolution output.
        """
        feature_map = self.conv(x)
        return [feature_map]
|
|
|
|
|
|
@HEADS.register_module()
class ExampleDecodeHead(BaseDecodeHead):
    """Minimal decode head used by the tests in this file.

    The base class is initialised with the positional arguments ``3, 3`` and
    ``num_classes=19``; the head is registered in the ``HEADS`` registry so
    it can be referenced by name from a config.
    """

    def __init__(self):
        super().__init__(3, 3, num_classes=19)

    def forward(self, inputs):
        """Classify the first entry of ``inputs`` with ``cls_seg``."""
        first_feature = inputs[0]
        return self.cls_seg(first_feature)
|
|
|
|
|
|
def get_model(type='EncoderDecoder',
              backbone='ExampleBackbone',
              decode_head='ExampleDecodeHead'):
    """Build a segmentor from registry type names.

    Args:
        type (str): Registry name of the segmentor to build.
        backbone (str): Registry name of the backbone.
        decode_head (str): Registry name of the decode head.

    Returns:
        The object returned by ``mmseg.models.build_segmentor`` for a config
        with ``train_cfg=None`` and ``test_cfg=dict(mode='whole')``.
    """
    from mmseg.models import build_segmentor

    settings = {
        'type': type,
        'backbone': {'type': backbone},
        'decode_head': {'type': decode_head},
        'train_cfg': None,
        'test_cfg': {'mode': 'whole'},
    }
    return build_segmentor(ConfigDict(**settings))
|
|
|
|
|
|
def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
|
|
"""Create a superset of inputs needed to run test or train batches.
|
|
|
|
Args:
|
|
input_shape (tuple):
|
|
input batch dimensions
|
|
|
|
num_classes (int):
|
|
number of semantic classes
|
|
"""
|
|
(N, C, H, W) = input_shape
|
|
|
|
rng = np.random.RandomState(0)
|
|
|
|
imgs = rng.rand(*input_shape)
|
|
segs = rng.randint(
|
|
low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
|
|
|
|
img_metas = [{
|
|
'img_shape': (H, W, C),
|
|
'ori_shape': (H, W, C),
|
|
'pad_shape': (H, W, C),
|
|
'filename': '<demo>.png',
|
|
'scale_factor': 1.0,
|
|
'flip': False,
|
|
'flip_direction': 'horizontal'
|
|
} for _ in range(N)]
|
|
|
|
mm_inputs = {
|
|
'imgs': torch.FloatTensor(imgs),
|
|
'img_metas': img_metas,
|
|
'gt_semantic_seg': torch.LongTensor(segs)
|
|
}
|
|
return mm_inputs
|
|
|
|
|
|
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_encoderdecoder_simple_test(backend_type):
    """Check ``simple_test`` of the segmentor against its rewritten version.

    The reference result comes from ``get_model_outputs`` and the rewritten
    result from ``get_rewrite_outputs`` using a deploy config for
    ``backend_type``.
    """
    segmentor = get_model()
    segmentor.cpu().eval()

    deploy_cfg = mmcv.Config({
        'backend_config': {'type': backend_type},
        'onnx_config': {'output_names': ['result'], 'input_shape': None},
        'codebase_config': {'type': 'mmseg', 'task': 'Segmentation'},
    })

    # A ``ModuleList`` of heads is resolved to its last entry.
    head = segmentor.decode_head
    final_head = head[-1] if isinstance(head, nn.ModuleList) else head
    mm_inputs = _demo_mm_inputs(num_classes=final_head.num_classes)
    imgs = mm_inputs['imgs']
    img_metas = mm_inputs['img_metas']

    model_outputs = get_model_outputs(
        segmentor, 'simple_test', dict(img=imgs, img_meta=img_metas))

    # Only the first two entries of ``img_shape`` are passed to the wrapper.
    height, width = img_metas[0]['img_shape'][:2]
    wrapped_model = WrapModel(
        segmentor, 'simple_test', img_meta={'img_shape': (height, width)})

    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=dict(img=imgs),
        deploy_cfg=deploy_cfg)

    if not is_backend_output:
        assert rewrite_outputs is not None
        return

    # Two leading singleton dims are added to the reference before comparing.
    expected = torch.tensor(model_outputs[0]).unsqueeze(0).unsqueeze(0)
    assert torch.allclose(rewrite_outputs[0], expected)
|
|
|
|
|
|
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_basesegmentor_forward(backend_type):
    """Check ``forward`` of the segmentor against its rewritten version.

    The reference result comes from ``get_model_outputs`` (called with
    ``return_loss=False``) and the rewritten result from
    ``get_rewrite_outputs`` using a deploy config for ``backend_type``.
    """
    segmentor = get_model()
    segmentor.cpu().eval()

    deploy_cfg = mmcv.Config({
        'backend_config': {'type': backend_type},
        'onnx_config': {'output_names': ['result'], 'input_shape': None},
        'codebase_config': {'type': 'mmseg', 'task': 'Segmentation'},
    })

    # A ``ModuleList`` of heads is resolved to its last entry.
    head = segmentor.decode_head
    final_head = head[-1] if isinstance(head, nn.ModuleList) else head
    mm_inputs = _demo_mm_inputs(num_classes=final_head.num_classes)
    imgs = mm_inputs['imgs']
    img_metas = mm_inputs['img_metas']

    # The reference call wraps the image batch and metas in lists.
    reference_inputs = dict(
        img=[imgs], img_metas=[img_metas], return_loss=False)
    model_outputs = get_model_outputs(segmentor, 'forward', reference_inputs)

    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=WrapModel(segmentor, 'forward'),
        model_inputs=dict(img=imgs),
        deploy_cfg=deploy_cfg)

    if not is_backend_output:
        assert rewrite_outputs is not None
        return

    actual = torch.tensor(rewrite_outputs[0])
    # Two leading singleton dims are added to the reference before comparing.
    expected = torch.tensor(model_outputs[0]).unsqueeze(0).unsqueeze(0)
    assert torch.allclose(actual, expected)
|
|
|
|
|
|
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_aspphead_forward(backend_type):
    """Check ``ASPPHead.forward`` against its rewritten version.

    The reference output is computed under ``torch.no_grad()`` and compared
    with the output of ``get_rewrite_outputs`` using ``rtol=1e-03`` and
    ``atol=1e-05``.
    """
    from mmseg.models.decode_heads import ASPPHead

    head = ASPPHead(
        in_channels=32, channels=16, num_classes=19, dilations=(1, 12, 24))
    head = head.eval()

    feature_shape = (1, 32, 45, 45)
    deploy_cfg = mmcv.Config({
        'backend_config': {'type': backend_type},
        'onnx_config': {
            'output_names': ['result'],
            'input_shape': feature_shape,
        },
        'codebase_config': {'type': 'mmseg', 'task': 'Segmentation'},
    })

    inputs = [torch.randn(*feature_shape)]
    with torch.no_grad():
        model_outputs = get_model_outputs(head, 'forward', dict(inputs=inputs))

    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=WrapModel(head, 'forward'),
        model_inputs=dict(inputs=inputs),
        deploy_cfg=deploy_cfg)

    if not is_backend_output:
        assert rewrite_outputs is not None
        return

    actual = torch.tensor(rewrite_outputs[0])
    assert torch.allclose(actual, model_outputs, rtol=1e-03, atol=1e-05)
|
|
|
|
|
|
@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_psphead_forward(backend_type):
    """Check ``PSPHead.forward`` against its rewritten version.

    The reference output is computed under ``torch.no_grad()`` and compared
    with the output of ``get_rewrite_outputs``.
    """
    from mmseg.models.decode_heads import PSPHead

    head = PSPHead(in_channels=32, channels=16, num_classes=19)
    head = head.eval()

    deploy_cfg = mmcv.Config({
        'backend_config': {'type': backend_type},
        'onnx_config': {'output_names': ['result'], 'input_shape': None},
        'codebase_config': {'type': 'mmseg', 'task': 'Segmentation'},
    })

    inputs = [torch.randn(1, 32, 45, 45)]
    with torch.no_grad():
        model_outputs = get_model_outputs(head, 'forward', dict(inputs=inputs))

    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=WrapModel(head, 'forward'),
        model_inputs=dict(inputs=inputs),
        deploy_cfg=deploy_cfg)

    if not is_backend_output:
        assert rewrite_outputs is not None
        return

    actual = torch.tensor(rewrite_outputs[0])
    # NOTE(review): rtol=1 / atol=1 is very permissive and will accept large
    # numerical differences -- confirm this looseness is intentional.
    assert torch.allclose(actual, model_outputs, rtol=1, atol=1)
|