72 lines
2.9 KiB
Python
72 lines
2.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest.mock import Mock
|
|
|
|
import torch
|
|
|
|
from mmengine.data import BaseDataElement
|
|
from mmengine.hooks import NaiveVisualizationHook
|
|
|
|
|
|
class TestNaiveVisualizationHook:
    """Smoke tests for ``NaiveVisualizationHook``."""

    def test_after_train_iter(self):
        """Call ``after_test_iter`` once per metainfo combination.

        Each scenario wraps the same random input tensor in a one-element
        data batch whose ground-truth sample carries a different subset of
        metainfo keys. The test makes no assertions; it only checks that
        every call completes without raising.

        NOTE(review): the method name says "train" while the body exercises
        ``after_test_iter`` -- confirm whether the name should be updated.
        """
        hook = NaiveVisualizationHook()
        runner = Mock(iter=1)
        runner.writer.add_image = Mock()
        inputs = torch.randn(1, 3, 15, 15)
        batch_idx = 10

        # One metainfo dict per scenario, listed in execution order.
        metainfo_cases = [
            # test with normalize, resize, pad
            {
                'img_norm_cfg': {
                    'mean': (0, 0, 0),
                    'std': (0.5, 0.5, 0.5),
                    'to_bgr': True,
                },
                'scale': (10, 10),
                'pad_shape': (15, 15, 3),
                'ori_height': 5,
                'ori_width': 5,
                'img_path': 'tmp.jpg',
            },
            # test with resize, pad
            {
                'scale': (10, 10),
                'pad_shape': (15, 15, 3),
                'ori_height': 5,
                'ori_width': 5,
                'img_path': 'tmp.jpg',
            },
            # test with only resize
            {
                'scale': (15, 15),
                'ori_height': 5,
                'ori_width': 5,
                'img_path': 'tmp.jpg',
            },
            # test with only pad
            {
                'pad_shape': (15, 15, 3),
                'ori_height': 5,
                'ori_width': 5,
                'img_path': 'tmp.jpg',
            },
            # test no transform
            {
                'ori_height': 15,
                'ori_width': 15,
                'img_path': 'tmp.jpg',
            },
        ]

        for metainfo in metainfo_cases:
            # Fresh ground-truth and prediction samples for every scenario.
            gt_sample = BaseDataElement(metainfo=metainfo)
            predictions = [BaseDataElement()]
            batch = [dict(inputs=inputs, data_sample=gt_sample)]
            hook.after_test_iter(runner, batch_idx, batch, predictions)
|