# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv import ConfigDict
from mmseg.models import BACKBONES, HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead

from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Codebase
from mmdeploy.utils.test import (WrapModel, get_model_outputs,
                                 get_rewrite_outputs)

import_codebase(Codebase.MMSEG)


@BACKBONES.register_module()
class ExampleBackbone(nn.Module):

    def __init__(self):
        super(ExampleBackbone, self).__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def init_weights(self, pretrained=None):
        pass

    def forward(self, x):
        return [self.conv(x)]


@HEADS.register_module()
class ExampleDecodeHead(BaseDecodeHead):

    def __init__(self):
        super(ExampleDecodeHead, self).__init__(3, 3, num_classes=19)

    def forward(self, inputs):
        return self.cls_seg(inputs[0])


def get_model(type='EncoderDecoder',
              backbone='ExampleBackbone',
              decode_head='ExampleDecodeHead'):
    from mmseg.models import build_segmentor
    cfg = ConfigDict(
        type=type,
        backbone=dict(type=backbone),
        decode_head=dict(type=decode_head),
        train_cfg=None,
        test_cfg=dict(mode='whole'))
    segmentor = build_segmentor(cfg)
    return segmentor


def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): input batch dimensions
        num_classes (int): number of semantic classes
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    segs = rng.randint(
        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
        'flip_direction': 'horizontal'
    } for _ in range(N)]
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs


@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_encoderdecoder_simple_test(backend_type):
    segmentor = get_model()
    segmentor.cpu().eval()

    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type),
            onnx_config=dict(output_names=['result'], input_shape=None),
            codebase_config=dict(type='mmseg', task='Segmentation')))

    if isinstance(segmentor.decode_head, nn.ModuleList):
        num_classes = segmentor.decode_head[-1].num_classes
    else:
        num_classes = segmentor.decode_head.num_classes
    mm_inputs = _demo_mm_inputs(num_classes=num_classes)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    model_inputs = {'img': imgs, 'img_meta': img_metas}
    model_outputs = get_model_outputs(segmentor, 'simple_test', model_inputs)

    img_meta = {
        'img_shape':
        (img_metas[0]['img_shape'][0], img_metas[0]['img_shape'][1])
    }
    wrapped_model = WrapModel(segmentor, 'simple_test', img_meta=img_meta)
    rewrite_inputs = {
        'img': imgs,
    }
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
    if is_backend_output:
        rewrite_outputs = rewrite_outputs[0]
        model_outputs = torch.tensor(model_outputs[0])
        model_outputs = model_outputs.unsqueeze(0).unsqueeze(0)
        assert torch.allclose(rewrite_outputs, model_outputs)
    else:
        assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_basesegmentor_forward(backend_type):
    segmentor = get_model()
    segmentor.cpu().eval()

    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type),
            onnx_config=dict(output_names=['result'], input_shape=None),
            codebase_config=dict(type='mmseg', task='Segmentation')))

    if isinstance(segmentor.decode_head, nn.ModuleList):
        num_classes = segmentor.decode_head[-1].num_classes
    else:
        num_classes = segmentor.decode_head.num_classes
    mm_inputs = _demo_mm_inputs(num_classes=num_classes)
    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    model_inputs = {
        'img': [imgs],
        'img_metas': [img_metas],
        'return_loss': False
    }
    model_outputs = get_model_outputs(segmentor, 'forward', model_inputs)

    wrapped_model = WrapModel(segmentor, 'forward')
    rewrite_inputs = {'img': imgs}
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
    if is_backend_output:
        rewrite_outputs = torch.tensor(rewrite_outputs[0])
        model_outputs = torch.tensor(model_outputs[0])
        model_outputs = model_outputs.unsqueeze(0).unsqueeze(0)
        assert torch.allclose(rewrite_outputs, model_outputs)
    else:
        assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_aspphead_forward(backend_type):
    from mmseg.models.decode_heads import ASPPHead
    head = ASPPHead(
        in_channels=32, channels=16, num_classes=19,
        dilations=(1, 12, 24)).eval()

    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type),
            onnx_config=dict(
                output_names=['result'], input_shape=(1, 32, 45, 45)),
            codebase_config=dict(type='mmseg', task='Segmentation')))
    inputs = [torch.randn(1, 32, 45, 45)]
    model_inputs = {'inputs': inputs}
    with torch.no_grad():
        model_outputs = get_model_outputs(head, 'forward', model_inputs)
    wrapped_model = WrapModel(head, 'forward')
    rewrite_inputs = {'inputs': inputs}
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
    if is_backend_output:
        rewrite_outputs = torch.tensor(rewrite_outputs[0])
        assert torch.allclose(
            rewrite_outputs, model_outputs, rtol=1e-03, atol=1e-05)
    else:
        assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', ['onnxruntime', 'ncnn'])
def test_psphead_forward(backend_type):
    from mmseg.models.decode_heads import PSPHead
    head = PSPHead(in_channels=32, channels=16, num_classes=19).eval()

    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type),
            onnx_config=dict(output_names=['result'], input_shape=None),
            codebase_config=dict(type='mmseg', task='Segmentation')))
    inputs = [torch.randn(1, 32, 45, 45)]
    model_inputs = {'inputs': inputs}
    with torch.no_grad():
        model_outputs = get_model_outputs(head, 'forward', model_inputs)
    wrapped_model = WrapModel(head, 'forward')
    rewrite_inputs = {'inputs': inputs}
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg)
    if is_backend_output:
        rewrite_outputs = torch.tensor(rewrite_outputs[0])
        assert torch.allclose(rewrite_outputs, model_outputs, rtol=1, atol=1)
    else:
        assert rewrite_outputs is not None