mmengine/tests/test_hook/test_naive_visualization_hook.py
liukuikun 5f8f36e6a5
refactor visualization (#147)
* [WIP] add inline

* refactor vis module

* [Refactor] according to review

* [Fix] fix comment

* fix some error

* Get sub-visualizer via Visualizer.get_instance

* fix conflict

* fix lint

* fix unit test

* fix mypy

* fix comment

* fix lint

* update docstr

* update

* update instancedata

* replace __mro__ with issubclass

Co-authored-by: PJLAB\huanghaian <1286304229@qq.com>
Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
2022-04-15 15:56:06 +08:00

72 lines
2.9 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
import torch
from mmengine.data import BaseDataElement
from mmengine.hooks import NaiveVisualizationHook
class TestNaiveVisualizationHook:
    """Smoke tests for :class:`NaiveVisualizationHook`."""

    def test_after_train_iter(self):
        """Feed ``after_test_iter`` batches with varied image metainfo.

        NOTE(review): despite its name, this test exercises
        ``NaiveVisualizationHook.after_test_iter``, not ``after_train_iter``.
        The name is kept so existing ``pytest -k`` selections keep working.

        Each case builds a fresh ground-truth ``BaseDataElement``. Its
        metainfo enables a different subset of the normalize / resize / pad
        settings. A one-element batch with that sample is then passed through
        the hook. The test only checks that no exception is raised; it makes
        no assertions about what the hook sends to the visualizer.
        """
        naive_visualization_hook = NaiveVisualizationHook()
        # ``Mock`` auto-creates attributes, so ``runner.visualizer`` is itself
        # a Mock; ``add_image`` is replaced explicitly to make that obvious.
        runner = Mock(iter=1)
        runner.visualizer.add_image = Mock()
        # A single 3-channel 15x15 image; every case below describes it.
        inputs = torch.randn(1, 3, 15, 15)
        batch_idx = 10

        # (label, metainfo) pairs, run in this order against one hook.
        cases = [
            ('normalize, resize and pad',
             dict(
                 img_norm_cfg=dict(
                     mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True),
                 scale=(10, 10),
                 pad_shape=(15, 15, 3),
                 ori_height=5,
                 ori_width=5,
                 img_path='tmp.jpg')),
            ('resize and pad',
             dict(
                 scale=(10, 10),
                 pad_shape=(15, 15, 3),
                 ori_height=5,
                 ori_width=5,
                 img_path='tmp.jpg')),
            ('resize only',
             dict(
                 scale=(15, 15),
                 ori_height=5,
                 ori_width=5,
                 img_path='tmp.jpg')),
            ('pad only',
             dict(
                 pad_shape=(15, 15, 3),
                 ori_height=5,
                 ori_width=5,
                 img_path='tmp.jpg')),
            # The original size equals the 15x15 input, so nothing to undo.
            ('no transform',
             dict(ori_height=15, ori_width=15, img_path='tmp.jpg')),
        ]

        for case_name, metainfo in cases:
            gt_datasamples = BaseDataElement(metainfo=metainfo)
            pred_datasamples = [BaseDataElement()]
            data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)]
            try:
                naive_visualization_hook.after_test_iter(
                    runner, batch_idx, data_batch, pred_datasamples)
            except Exception as err:
                # Re-raise with the case label: inside a loop the traceback
                # alone no longer shows which metainfo combination failed.
                raise AssertionError(
                    f'after_test_iter failed for case: {case_name}') from err