mmsegmentation/tests/test_engine/test_visualization_hook.py
谢昕辰 5d9650838e
[Fix] Fix demo scripts (#1815)
* [Feature] Add SegVisualizer

* change name to visualizer_example

* fix inference api

* fix video demo and refine inference api

* fix

* mmseg compose

* set default device to cuda:0

* fix import

* update dir

* rm engine/visualizer ut

* refine inference api and docs

* rename

Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
2022-07-29 18:37:20 +08:00

63 lines
2.1 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock
import torch
from mmengine.data import PixelData
from mmseg.data import SegDataSample
from mmseg.engine.hooks import SegVisualizationHook
from mmseg.visualization import SegLocalVisualizer
class TestVisualizationHook(TestCase):
    """Smoke tests for ``SegVisualizationHook``.

    Each test builds the hook with different options and drives one of its
    per-iteration entry points with a ``Mock`` runner and a small synthetic
    batch; the tests only check that the calls complete without raising.
    """

    def setUp(self) -> None:
        """Register a visualizer and build a two-sample batch and outputs."""
        h = 288
        w = 512
        num_class = 2
        # Create (or fetch) the visualizer instance named 'visualizer';
        # presumably the hook looks it up as the current instance — TODO
        # confirm against SegVisualizationHook.
        visualizer = SegLocalVisualizer.get_instance('visualizer')
        # Set the meta on the instance, not on the class: assigning to
        # ``SegLocalVisualizer.dataset_meta`` rebinds a class attribute that
        # would then leak into every other test using SegLocalVisualizer.
        visualizer.dataset_meta = dict(
            classes=('background', 'foreground'),
            palette=[[120, 120, 120], [6, 230, 230]])

        data_sample = SegDataSample()
        data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
        # The same sample object is reused for both batch entries.
        self.data_batch = [{'data_sample': data_sample}] * 2

        # Random class indices in [0, num_class) with shape (1, h, w).
        pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
        pred_sem_seg = PixelData(**pred_sem_seg_data)
        pred_seg_data_sample = SegDataSample()
        pred_seg_data_sample.pred_sem_seg = pred_sem_seg
        self.outputs = [pred_seg_data_sample] * 2

    def test_after_iter(self):
        """Call the private ``_after_iter`` in train, val and test modes."""
        runner = Mock()
        runner.iter = 1
        hook = SegVisualizationHook(draw=True, interval=1)
        hook._after_iter(
            runner, 1, self.data_batch, self.outputs, mode='train')
        hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='val')
        hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='test')

    def test_after_val_iter(self):
        """Call ``after_val_iter`` with draw off, draw on, and show on."""
        runner = Mock()
        runner.iter = 2
        hook = SegVisualizationHook(interval=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)
        hook = SegVisualizationHook(draw=True, interval=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)
        hook = SegVisualizationHook(
            draw=True, interval=1, show=True, wait_time=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

    def test_after_test_iter(self):
        """Call ``after_test_iter`` with drawing enabled."""
        runner = Mock()
        runner.iter = 3
        hook = SegVisualizationHook(draw=True, interval=1)
        # Previously called ``after_iter``; the test-loop entry point that
        # this test is named for is ``after_test_iter``.
        hook.after_test_iter(runner, 1, self.data_batch, self.outputs)