# Copyright (c) OpenMMLab. All rights reserved. import tempfile import numpy as np import torch import torch.nn as nn from mmengine import ConfigDict from mmseg.apis import MMSegInferencer from mmseg.models import EncoderDecoder from mmseg.models.decode_heads.decode_head import BaseDecodeHead from mmseg.registry import MODELS from mmseg.utils import register_all_modules @MODELS.register_module(name='InferExampleHead') class ExampleDecodeHead(BaseDecodeHead): def __init__(self, num_classes=19, out_channels=None): super().__init__( 3, 3, num_classes=num_classes, out_channels=out_channels) def forward(self, inputs): return self.cls_seg(inputs[0]) @MODELS.register_module(name='InferExampleBackbone') class ExampleBackbone(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(3, 3, 3) def init_weights(self, pretrained=None): pass def forward(self, x): return [self.conv(x)] @MODELS.register_module(name='InferExampleModel') class ExampleModel(EncoderDecoder): def __init__(self, **kwargs): super().__init__(**kwargs) def test_inferencer(): register_all_modules() visualizer = dict( type='SegLocalVisualizer', vis_backends=[dict(type='LocalVisBackend')], name='visualizer') cfg_dict = dict( model=dict( type='InferExampleModel', data_preprocessor=dict(type='SegDataPreProcessor'), backbone=dict(type='InferExampleBackbone'), decode_head=dict(type='InferExampleHead'), test_cfg=dict(mode='whole')), visualizer=visualizer, test_dataloader=dict( dataset=dict( type='ExampleDataset', pipeline=[ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations'), dict(type='PackSegInputs') ]), )) cfg = ConfigDict(cfg_dict) model = MODELS.build(cfg.model) ckpt = model.state_dict() ckpt_filename = tempfile.mktemp() torch.save(ckpt, ckpt_filename) # test initialization infer = MMSegInferencer(cfg, ckpt_filename) # test forward img = np.random.randint(0, 256, (4, 4, 3)) infer(img) imgs = [img, img] infer(imgs) results = infer(imgs, out_dir=tempfile.gettempdir()) # test results assert 'predictions' in results assert 'visualization' in results assert len(results['predictions']) == 2 assert results['predictions'][0].shape == (4, 4)