# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import Mock

import torch
from mmengine.data import PixelData

from mmseg.data import SegDataSample
from mmseg.engine.hooks import SegVisualizationHook
from mmseg.visualization import SegLocalVisualizer


class TestVisualizationHook(TestCase):
    """Smoke tests that drive ``SegVisualizationHook`` with a mocked runner.

    The tests only check that the hook callbacks run without raising; no
    assertions are made on what the hook produces.
    """

    def setUp(self) -> None:
        """Build the shared visualizer, input batch and fake predictions."""
        h = 288
        w = 512
        num_class = 2

        # Create the visualizer instance named 'visualizer' and attach a
        # two-class palette at class level.
        SegLocalVisualizer.get_instance('visualizer')
        SegLocalVisualizer.dataset_meta = dict(
            classes=('background', 'foreground'),
            palette=[[120, 120, 120], [6, 230, 230]])

        # Input batch: two entries that reference the same data sample.
        data_sample = SegDataSample()
        data_sample.set_metainfo({'img_path': 'tests/data/color.jpg'})
        self.data_batch = [{'data_sample': data_sample}] * 2

        # Fake prediction: random integer labels in [0, num_class) with
        # shape (1, h, w), wrapped in a SegDataSample.
        pred_sem_seg_data = dict(data=torch.randint(0, num_class, (1, h, w)))
        pred_sem_seg = PixelData(**pred_sem_seg_data)
        pred_seg_data_sample = SegDataSample()
        pred_seg_data_sample.pred_sem_seg = pred_sem_seg
        self.outputs = [pred_seg_data_sample] * 2

    def test_after_iter(self):
        """Call ``_after_iter`` directly for the train, val and test modes."""
        runner = Mock()
        runner.iter = 1
        hook = SegVisualizationHook(draw=True, interval=1)
        hook._after_iter(
            runner, 1, self.data_batch, self.outputs, mode='train')
        hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='val')
        hook._after_iter(runner, 1, self.data_batch, self.outputs, mode='test')

    def test_after_val_iter(self):
        """Call ``after_val_iter`` with three hook configurations."""
        runner = Mock()
        runner.iter = 2

        # Constructor defaults except for the interval.
        hook = SegVisualizationHook(interval=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

        # Drawing enabled.
        hook = SegVisualizationHook(draw=True, interval=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

        # Drawing plus on-screen display with a wait time of 1.
        hook = SegVisualizationHook(
            draw=True, interval=1, show=True, wait_time=1)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

    def test_after_test_iter(self):
        """Call the ``after_test_iter`` entry point with drawing enabled."""
        runner = Mock()
        runner.iter = 3
        hook = SegVisualizationHook(draw=True, interval=1)
        # Use the test-phase callback, mirroring how test_after_val_iter
        # drives after_val_iter.
        hook.after_test_iter(runner, 1, self.data_batch, self.outputs)