mmsegmentation/tests/test_apis/test_inferencer.py
2023-06-19 13:08:04 +08:00

95 lines
2.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile

import numpy as np
import torch
import torch.nn as nn
from mmengine import ConfigDict

from mmseg.apis import MMSegInferencer
from mmseg.models import EncoderDecoder
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from mmseg.utils import register_all_modules
@MODELS.register_module(name='InferExampleHead')
class ExampleDecodeHead(BaseDecodeHead):
    """Toy decode head registered under the name ``InferExampleHead``.

    Only the classification layer (``cls_seg``) is applied, and only to the
    first element of the inputs it receives.

    Args:
        num_classes (int): Number of segmentation classes. Defaults to 19.
        out_channels (int, optional): Forwarded unchanged to
            ``BaseDecodeHead``.
    """

    def __init__(self, num_classes=19, out_channels=None):
        # The two leading positional 3s are BaseDecodeHead's first two
        # constructor arguments (presumably in_channels and channels).
        super().__init__(
            3,
            3,
            num_classes=num_classes,
            out_channels=out_channels,
        )

    def forward(self, inputs):
        """Run ``cls_seg`` on ``inputs[0]`` and return its output."""
        first_feature = inputs[0]
        return self.cls_seg(first_feature)
@MODELS.register_module(name='InferExampleBackbone')
class ExampleBackbone(nn.Module):
    """Toy backbone registered under the name ``InferExampleBackbone``.

    It holds a single 3x3 convolution (3 -> 3 channels, default padding of
    0). ``forward`` returns that convolution's output wrapped in a
    one-element list.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)

    def init_weights(self, pretrained=None):
        """Do nothing; ``pretrained`` is accepted and ignored."""

    def forward(self, x):
        """Apply the convolution and return ``[features]``."""
        features = self.conv(x)
        return [features]
@MODELS.register_module(name='InferExampleModel')
class ExampleModel(EncoderDecoder):
    """``EncoderDecoder`` registered under the name ``InferExampleModel``.

    Every keyword argument is passed straight through to
    ``EncoderDecoder.__init__``; no behavior is added.
    """

    def __init__(self, **model_kwargs):
        super().__init__(**model_kwargs)
def test_inferencer():
    """Smoke-test ``MMSegInferencer`` end to end with a toy model.

    The test proceeds in four steps:

    1. Build the registered toy model from an in-memory config.
    2. Save its ``state_dict`` as a checkpoint.
    3. Construct the inferencer from the config plus that checkpoint.
    4. Run inference three times: on a single image, on a list of two
       images, and on a list of two images with ``out_dir`` set.

    The last run's results are then checked for both expected keys, two
    predictions, and a (4, 4) prediction map per image.
    """
    register_all_modules()

    visualizer = dict(
        type='SegLocalVisualizer',
        vis_backends=[dict(type='LocalVisBackend')],
        name='visualizer')

    cfg_dict = dict(
        model=dict(
            type='InferExampleModel',
            data_preprocessor=dict(type='SegDataPreProcessor'),
            backbone=dict(type='InferExampleBackbone'),
            decode_head=dict(type='InferExampleHead'),
            test_cfg=dict(mode='whole')),
        visualizer=visualizer,
        test_dataloader=dict(
            dataset=dict(
                type='ExampleDataset',
                pipeline=[
                    dict(type='LoadImageFromFile'),
                    dict(type='LoadAnnotations'),
                    dict(type='PackSegInputs')
                ]), ))
    cfg = ConfigDict(cfg_dict)
    model = MODELS.build(cfg.model)

    # One private temporary directory holds both the checkpoint and the
    # inferencer outputs. This avoids the deprecated, race-prone
    # ``tempfile.mktemp`` and keeps files out of the shared system temp
    # dir. Everything is removed when the ``with`` block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_filename = osp.join(tmp_dir, 'checkpoint.pth')
        torch.save(model.state_dict(), ckpt_filename)

        # test initialization
        infer = MMSegInferencer(cfg, ckpt_filename)

        # test forward: a single HxWxC image, then a batch of two.
        # uint8 matches what a real decoded image looks like.
        img = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)
        infer(img)

        imgs = [img, img]
        infer(imgs)
        results = infer(imgs, out_dir=tmp_dir)

    # test results
    assert 'predictions' in results
    assert 'visualization' in results
    assert len(results['predictions']) == 2
    assert results['predictions'][0].shape == (4, 4)