# mmsegmentation/tests/test_engine/test_visualization_hook.py
#
# 64 lines
# 2.2 KiB
# Python
# Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
import torch
from mmengine.structures import PixelData
from mmseg.engine.hooks import SegVisualizationHook
# 2022-08-03 15:43:23 +08:00
from mmseg.structures import SegDataSample
from mmseg.visualization import SegLocalVisualizer
class TestVisualizationHook(TestCase):
    """Smoke tests that drive ``SegVisualizationHook`` through its callbacks.

    None of the tests assert on a result: each one passes as long as the
    hook calls return without raising.
    """

    def setUp(self) -> None:
        """Register a visualizer and build the shared batch/output fixtures.

        Sets ``self.data_batch`` (two inputs) and ``self.outputs`` (two
        predictions) that every test feeds to the hook.
        """
        h = 288
        w = 512
        num_class = 2
        # Relative path; presumably resolved against the repository root
        # when the test suite is launched from there -- TODO confirm.
        img_path = 'tests/data/color.jpg'

        # Create (or fetch) the visualizer instance named 'visualizer', then
        # attach the two-class meta information on the visualizer class.
        SegLocalVisualizer.get_instance('visualizer')
        SegLocalVisualizer.dataset_meta = dict(
            classes=('background', 'foreground'),
            palette=[[120, 120, 120], [6, 230, 230]])

        # Input side: a data sample that only carries the image path.
        data_sample = SegDataSample()
        data_sample.set_metainfo({'img_path': img_path})
        # ``* 2`` repeats a reference: both entries are the same dict.
        self.data_batch = [{'data_sample': data_sample}] * 2

        # Output side: a random label map with values in [0, num_class).
        pred_sem_seg = PixelData(
            data=torch.randint(0, num_class, (1, h, w)))
        pred_seg_data_sample = SegDataSample()
        pred_seg_data_sample.set_metainfo({'img_path': img_path})
        pred_seg_data_sample.pred_sem_seg = pred_sem_seg
        # As above, both entries are the same prediction object.
        self.outputs = [pred_seg_data_sample] * 2

    def test_after_iter(self):
        """Call ``_after_iter`` once for each of train, val and test mode."""
        runner = Mock()
        runner.iter = 1
        hook = SegVisualizationHook(draw=True, interval=1)
        for mode in ('train', 'val', 'test'):
            hook._after_iter(
                runner, 1, self.data_batch, self.outputs, mode=mode)

    def test_after_val_iter(self):
        """Call ``after_val_iter`` under three hook configurations."""
        runner = Mock()
        runner.iter = 2

        # 1) ``draw`` left at its default value.
        hook = SegVisualizationHook(interval=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

        # 2) Drawing enabled.
        hook = SegVisualizationHook(draw=True, interval=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

        # 3) Drawing enabled together with ``show`` and a ``wait_time``.
        hook = SegVisualizationHook(
            draw=True, interval=1, show=True, wait_time=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

    def test_after_test_iter(self):
        """Call ``after_test_iter`` with drawing enabled."""
        runner = Mock()
        runner.iter = 3
        hook = SegVisualizationHook(draw=True, interval=1)
        hook.after_test_iter(runner, 1, self.data_batch, self.outputs)