# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv import ConfigDict
from mmseg.models import BACKBONES, HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead

from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase, Task
from mmdeploy.utils.test import (WrapModel, check_backend, get_model_outputs,
                                 get_rewrite_outputs)

# Load mmdeploy's mmseg codebase support before any test runs (presumably
# this registers the mmseg rewriters exercised below -- confirm in mmdeploy).
import_codebase(Codebase.MMSEG)


@BACKBONES.register_module()
class ExampleBackbone(nn.Module):
    """Toy backbone: one 3x3 conv whose output is wrapped in a list."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def init_weights(self, pretrained=None):
        # Nothing to initialise or load for this toy backbone.
        pass

    def forward(self, x):
        # Segmentor necks/heads consume a list of feature maps.
        return [self.conv(x)]


@HEADS.register_module()
class ExampleDecodeHead(BaseDecodeHead):
    """Toy decode head predicting 19 classes from the first feature map."""

    def __init__(self):
        # Positional args are BaseDecodeHead's (in_channels, channels).
        super().__init__(3, 3, num_classes=19)

    def forward(self, inputs):
        return self.cls_seg(inputs[0])


def get_model(type='EncoderDecoder',
              backbone='ExampleBackbone',
              decode_head='ExampleDecodeHead'):
    """Build a segmentor from registered component names.

    Args:
        type (str): Registered segmentor type.
        backbone (str): Registered backbone type.
        decode_head (str): Registered decode head type.

    Returns:
        nn.Module: Segmentor built with no train config and a
        ``mode='whole'`` test config.
    """
    from mmseg.models import build_segmentor

    cfg = ConfigDict(
        type=type,
        backbone=dict(type=backbone),
        decode_head=dict(type=decode_head),
        train_cfg=None,
        test_cfg=dict(mode='whole'))
    segmentor = build_segmentor(cfg)

    return segmentor


def _get_num_classes(segmentor):
    """Return ``num_classes`` of the segmentor's (last) decode head.

    Cascaded segmentors keep their heads in an ``nn.ModuleList``; the last
    head produces the final prediction.
    """
    decode_head = segmentor.decode_head
    if isinstance(decode_head, nn.ModuleList):
        decode_head = decode_head[-1]
    return decode_head.num_classes


def _make_deploy_cfg(backend, input_shape=None):
    """Build the deploy config shared by the mmseg segmentation tests.

    Args:
        backend (Backend): Target backend; its ``.value`` is the config type.
        input_shape (tuple | None): Forwarded to ``onnx_config.input_shape``.

    Returns:
        mmcv.Config: Config with a single ONNX output named ``'result'`` and
        the ``mmseg`` / ``Segmentation`` codebase config.
    """
    return mmcv.Config(
        dict(
            backend_config=dict(type=backend.value),
            onnx_config=dict(output_names=['result'], input_shape=input_shape),
            codebase_config=dict(type='mmseg', task='Segmentation')))


def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): Input batch dimensions ``(N, C, H, W)``.
        num_classes (int): Number of semantic classes; ground-truth labels
            are drawn from ``[0, num_classes)``.

    Returns:
        dict: ``imgs`` (FloatTensor of ``input_shape``), ``img_metas`` (one
        meta dict per sample) and ``gt_semantic_seg`` (LongTensor of shape
        ``(N, 1, H, W)``).
    """
    (N, C, H, W) = input_shape

    rng = np.random.RandomState(0)

    imgs = rng.rand(*input_shape)
    # ``high`` is exclusive: use ``num_classes`` so every class label can
    # occur (``num_classes - 1`` dropped the last class and raised for
    # ``num_classes == 1``).
    segs = rng.randint(
        low=0, high=num_classes, size=(N, 1, H, W)).astype(np.uint8)

    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '.png',
        'scale_factor': 1.0,
        'flip': False,
        'flip_direction': 'horizontal'
    } for _ in range(N)]

    mm_inputs = {
        'imgs': torch.FloatTensor(imgs),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs


@pytest.mark.parametrize('backend',
                         [Backend.ONNXRUNTIME, Backend.OPENVINO, Backend.NCNN])
def test_encoderdecoder_simple_test(backend):
    """Rewritten ``EncoderDecoder.simple_test`` matches the PyTorch model."""
    check_backend(backend)
    segmentor = get_model()
    segmentor.cpu().eval()

    deploy_cfg = _make_deploy_cfg(backend)
    num_classes = _get_num_classes(segmentor)
    mm_inputs = _demo_mm_inputs(
        input_shape=(1, 3, 32, 32), num_classes=num_classes)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    model_inputs = {'img': imgs, 'img_meta': img_metas}
    model_outputs = get_model_outputs(segmentor, 'simple_test', model_inputs)

    # The wrapped model only receives the (H, W) part of the image shape.
    img_meta = {
        'img_shape':
        (img_metas[0]['img_shape'][0], img_metas[0]['img_shape'][1])
    }
    wrapped_model = WrapModel(segmentor, 'simple_test', img_meta=img_meta)
    rewrite_inputs = {
        'img': imgs,
    }
    rewrite_outputs, _ = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)

    model_outputs = torch.tensor(model_outputs[0])
    rewrite_outputs = rewrite_outputs[0].to(model_outputs).reshape(
        model_outputs.shape)
    assert torch.allclose(rewrite_outputs, model_outputs)


@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME, Backend.OPENVINO])
def test_basesegmentor_forward(backend):
    """Rewritten ``BaseSegmentor.forward`` matches the PyTorch model."""
    check_backend(backend)
    segmentor = get_model()
    segmentor.cpu().eval()

    deploy_cfg = _make_deploy_cfg(backend)
    num_classes = _get_num_classes(segmentor)
    mm_inputs = _demo_mm_inputs(num_classes=num_classes)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    # The original forward expects test-time-augmentation style lists.
    model_inputs = {
        'img': [imgs],
        'img_metas': [img_metas],
        'return_loss': False
    }
    model_outputs = get_model_outputs(segmentor, 'forward', model_inputs)

    wrapped_model = WrapModel(segmentor, 'forward')
    rewrite_inputs = {'img': imgs}
    rewrite_outputs, _ = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)

    model_outputs = torch.tensor(model_outputs[0])
    rewrite_outputs = rewrite_outputs[0].to(model_outputs).reshape(
        model_outputs.shape)
    assert torch.allclose(rewrite_outputs, model_outputs)


@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME, Backend.OPENVINO])
def test_aspphead_forward(backend):
    """Rewritten ``ASPPHead.forward`` matches the PyTorch head."""
    check_backend(backend)
    from mmseg.models.decode_heads import ASPPHead
    head = ASPPHead(
        in_channels=32, channels=16, num_classes=19,
        dilations=(1, 12, 24)).eval()

    deploy_cfg = _make_deploy_cfg(backend, input_shape=(1, 32, 45, 45))
    inputs = [torch.randn(1, 32, 45, 45)]
    model_inputs = {'inputs': inputs}
    with torch.no_grad():
        model_outputs = get_model_outputs(head, 'forward', model_inputs)
    wrapped_model = WrapModel(head, 'forward')
    rewrite_inputs = {'inputs': inputs}
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
    # Backend inference yields a sequence of outputs; compare the first.
    if is_backend_output:
        rewrite_outputs = rewrite_outputs[0]
    rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
        model_outputs.shape)
    assert torch.allclose(
        rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)


@pytest.mark.parametrize('backend',
                         [Backend.ONNXRUNTIME, Backend.OPENVINO, Backend.NCNN])
def test_psphead_forward(backend):
    """Rewritten ``PSPHead.forward`` roughly matches the PyTorch head."""
    check_backend(backend)
    from mmseg.models.decode_heads import PSPHead
    head = PSPHead(in_channels=32, channels=16, num_classes=19).eval()

    deploy_cfg = _make_deploy_cfg(backend)
    inputs = [torch.randn(1, 32, 45, 45)]
    model_inputs = {'inputs': inputs}
    with torch.no_grad():
        model_outputs = get_model_outputs(head, 'forward', model_inputs)
    wrapped_model = WrapModel(head, 'forward')
    rewrite_inputs = {'inputs': inputs}
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
    if is_backend_output:
        rewrite_outputs = rewrite_outputs[0]
    rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
        model_outputs.shape)
    # NOTE(review): rtol=1/atol=1 makes this check very weak -- presumably
    # loosened for backend differences in adaptive pooling; confirm and
    # tighten per backend if possible.
    assert torch.allclose(rewrite_outputs, model_outputs, rtol=1, atol=1)


@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
def test_emamodule_forward(backend):
    """Rewritten ``EMAModule.forward`` matches the PyTorch module."""
    check_backend(backend)
    from mmseg.models.decode_heads.ema_head import EMAModule
    head = EMAModule(8, 2, 2, 1.0).eval()

    deploy_cfg = _make_deploy_cfg(backend, input_shape=(1, 8, 16, 16))
    feats = torch.randn(1, 8, 16, 16)
    model_inputs = {'feats': feats}
    with torch.no_grad():
        model_outputs = get_model_outputs(head, 'forward', model_inputs)
    wrapped_model = WrapModel(head, 'forward')
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=model_inputs,
        deploy_cfg=deploy_cfg)
    if is_backend_output:
        rewrite_outputs = rewrite_outputs[0]
    rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
        model_outputs.shape)
    assert torch.allclose(
        rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)


@pytest.mark.parametrize('is_dynamic_shape', [True, False])
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
def test_upconvblock_forward(backend, is_dynamic_shape):
    """Rewritten ``UpConvBlock.forward`` matches PyTorch, static or dynamic."""
    check_backend(backend)
    from mmseg.models.backbones.unet import BasicConvBlock
    from mmseg.models.utils import UpConvBlock

    head = UpConvBlock(BasicConvBlock, 16, 8, 8).eval()
    # Batch and spatial axes are dynamic only in the dynamic-shape case.
    dynamic_axes = {
        'x': {
            0: 'b',
            2: 'h',
            3: 'w'
        },
        'skip': {
            0: 'b',
            2: 'h',
            3: 'w'
        },
        'output': {
            0: 'b',
            2: 'h',
            3: 'w'
        },
    } if is_dynamic_shape else None
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend.value),
            onnx_config=dict(
                input_names=['skip', 'x'],
                output_names=['output'],
                dynamic_axes=dynamic_axes),
            codebase_config=dict(
                type=Codebase.MMSEG.value, task=Task.SEGMENTATION.value)))
    # ``x`` is half the spatial size of ``skip``.
    x = torch.randn(1, 16, 16, 16)
    skip = torch.randn(1, 8, 32, 32)
    model_inputs = {'x': x, 'skip': skip}
    with torch.no_grad():
        model_outputs = get_model_outputs(head, 'forward', model_inputs)

    wrapped_model = WrapModel(head, 'forward')
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=model_inputs,
        deploy_cfg=deploy_cfg)
    if is_backend_output:
        rewrite_outputs = rewrite_outputs[0]
    rewrite_outputs = rewrite_outputs.to(model_outputs).reshape(
        model_outputs.shape)
    assert torch.allclose(
        rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)