mmsegmentation/tests/test_engine/test_visualization_hook.py
Miao Zheng 8de0050f25
[Refactor] data flow (#1956)
* [WIP] Refactor data flow

* model return

* [WIP] Refactor data flow

* support data_samples is optional

* fix benchmark

* fix base

* minors

* rebase

* fix api

* ut

* fix api inference

* comments

* docstring

* docstring

* docstring

* fix bug of slide inference

* add assert c > 1
2022-08-26 15:54:23 +08:00

64 lines
2.2 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
import torch
from mmengine.structures import PixelData
from mmseg.engine.hooks import SegVisualizationHook
from mmseg.structures import SegDataSample
from mmseg.visualization import SegLocalVisualizer
class TestVisualizationHook(TestCase):
    """Smoke tests for ``SegVisualizationHook``.

    Each test drives one of the hook's iteration callbacks with a mocked
    runner, a two-element input batch and two fake predictions. None of the
    tests makes assertions; they pass as long as the calls do not raise.
    """

    def setUp(self) -> None:
        """Prepare the visualizer, the input batch and the fake outputs."""
        height, width = 288, 512
        num_classes = 2
        img_path = 'tests/data/color.jpg'

        # Request the visualizer instance named 'visualizer', then attach the
        # two-class dataset meta on the visualizer class itself.
        SegLocalVisualizer.get_instance('visualizer')
        SegLocalVisualizer.dataset_meta = {
            'classes': ('background', 'foreground'),
            'palette': [[120, 120, 120], [6, 230, 230]],
        }

        # Input side: one sample carrying only the image path. The same dict
        # object is listed twice, matching a ``[item] * 2`` batch.
        input_sample = SegDataSample()
        input_sample.set_metainfo(dict(img_path=img_path))
        batch_item = {'data_sample': input_sample}
        self.data_batch = [batch_item, batch_item]

        # Output side: one sample with a random (1, height, width) label map
        # whose values lie in [0, num_classes); again shared twice.
        prediction = SegDataSample()
        prediction.set_metainfo(dict(img_path=img_path))
        prediction.pred_sem_seg = PixelData(
            data=torch.randint(0, num_classes, (1, height, width)))
        self.outputs = [prediction, prediction]

    def test_after_iter(self):
        """Call ``_after_iter`` once for each of train, val and test mode."""
        mock_runner = Mock()
        mock_runner.iter = 1
        vis_hook = SegVisualizationHook(draw=True, interval=1)
        for mode in ('train', 'val', 'test'):
            vis_hook._after_iter(
                mock_runner, 1, self.data_batch, self.outputs, mode=mode)

    def test_after_val_iter(self):
        """Call ``after_val_iter`` under three hook configurations."""
        mock_runner = Mock()
        mock_runner.iter = 2
        # Default drawing flag, drawing enabled, and drawing with on-screen
        # display; a fresh hook is built for every configuration.
        hook_configs = (
            dict(interval=1),
            dict(draw=True, interval=1),
            dict(draw=True, interval=1, show=True, wait_time=1),
        )
        for hook_kwargs in hook_configs:
            vis_hook = SegVisualizationHook(**hook_kwargs)
            vis_hook.after_val_iter(
                mock_runner, 1, self.data_batch, self.outputs)

    def test_after_test_iter(self):
        """Call ``after_test_iter`` with drawing enabled."""
        mock_runner = Mock()
        mock_runner.iter = 3
        vis_hook = SegVisualizationHook(draw=True, interval=1)
        vis_hook.after_test_iter(
            mock_runner, 1, self.data_batch, self.outputs)