# mmsegmentation/tests/test_apis/test_inferencer.py
#
# 61 lines
# 1.7 KiB
# Python
# Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
import numpy as np
import torch
from mmengine import ConfigDict
from utils import * # noqa: F401, F403
from mmseg.apis import MMSegInferencer
from mmseg.registry import MODELS
from mmseg.utils import register_all_modules
def test_inferencer():
    """Smoke-test ``MMSegInferencer`` end to end with a toy model.

    The test builds the example model from an in-memory config, saves its
    ``state_dict`` as a checkpoint, constructs an inferencer from the config
    and that checkpoint, and then runs it on a single random image and on a
    list of two images. It finally checks that the returned dict exposes the
    ``predictions`` and ``visualization`` keys and that every prediction has
    the spatial shape of the input image.

    All files produced by the test (the checkpoint and whatever the
    inferencer writes to ``out_dir``) live in a temporary directory that is
    removed when the test finishes.
    """
    # Local import so this function does not depend on edits to the
    # file-level import block.
    import os.path as osp

    register_all_modules()

    visualizer = dict(
        type='SegLocalVisualizer',
        vis_backends=[dict(type='LocalVisBackend')],
        name='visualizer')

    cfg_dict = dict(
        model=dict(
            type='InferExampleModel',
            data_preprocessor=dict(type='SegDataPreProcessor'),
            backbone=dict(type='InferExampleBackbone'),
            decode_head=dict(type='InferExampleHead'),
            test_cfg=dict(mode='whole')),
        visualizer=visualizer,
        test_dataloader=dict(
            dataset=dict(
                type='ExampleDataset',
                pipeline=[
                    dict(type='LoadImageFromFile'),
                    dict(type='LoadAnnotations'),
                    dict(type='PackSegInputs')
                ]), ))
    cfg = ConfigDict(cfg_dict)
    model = MODELS.build(cfg.model)

    # ``tempfile.mktemp`` is deprecated (race-prone) and leaves the file
    # behind; a TemporaryDirectory gives a private location that is cleaned
    # up automatically, even if an assertion below fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt = model.state_dict()
        ckpt_filename = osp.join(tmp_dir, 'checkpoint.pth')
        torch.save(ckpt, ckpt_filename)

        # test initialization
        infer = MMSegInferencer(cfg, ckpt_filename)

        # test forward
        img = np.random.randint(0, 256, (4, 4, 3))
        infer(img)

        imgs = [img, img]
        infer(imgs)
        # Write outputs into the temporary directory rather than the shared
        # system temp dir so nothing leaks between test runs.
        results = infer(imgs, out_dir=tmp_dir)

        # test results
        assert 'predictions' in results
        assert 'visualization' in results
        assert len(results['predictions']) == 2
        assert results['predictions'][0].shape == (4, 4)