mmocr/tests/test_engine/test_hooks/test_visualization_hook.py
liukuikun ef683206ed
[Api] vis hook and data flow api (#1185)
* vis hook and data flow api

* fix comment

* add TODO for merging and rewriting after MultiDatasetWrapper
2022-08-08 11:37:46 +08:00

80 lines
2.6 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
import time
from unittest import TestCase
from unittest.mock import Mock
import torch
from mmengine.data import InstanceData
from mmocr.engine.hooks import VisualizationHook
from mmocr.structures import TextDetDataSample
from mmocr.visualization import TextDetLocalVisualizer
def _rand_bboxes(num_boxes, h, w):
cx, cy, bw, bh = torch.rand(num_boxes, 4).T
tl_x = ((cx * w) - (w * bw / 2)).clamp(0, w).unsqueeze(0)
tl_y = ((cy * h) - (h * bh / 2)).clamp(0, h).unsqueeze(0)
br_x = ((cx * w) + (w * bw / 2)).clamp(0, w).unsqueeze(0)
br_y = ((cy * h) + (h * bh / 2)).clamp(0, h).unsqueeze(0)
bboxes = torch.cat([tl_x, tl_y, br_x, br_y], dim=0).T
return bboxes
class TestVisualizationHook(TestCase):
def setUp(self) -> None:
data_sample = TextDetDataSample()
data_sample.set_metainfo({
'img_path':
osp.join(
osp.dirname(__file__),
'../../data/det_toy_dataset/imgs/test/img_1.jpg')
})
pred_instances = InstanceData()
pred_instances.bboxes = _rand_bboxes(5, 10, 12)
pred_instances.labels = torch.randint(0, 2, (5, ))
pred_instances.scores = torch.rand((5, ))
data_sample.pred_instances = pred_instances
self.outputs = [data_sample] * 2
self.data_batch = None
def test_after_val_iter(self):
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
TextDetLocalVisualizer.get_instance(
'visualizer_val',
vis_backends=[dict(type='LocalVisBackend', img_save_dir='')],
save_dir=timestamp)
runner = Mock()
runner.iter = 1
hook = VisualizationHook(enable=True, interval=1)
self.assertFalse(osp.exists(timestamp))
hook.after_val_iter(runner, 1, self.data_batch, self.outputs)
self.assertTrue(osp.exists(timestamp))
shutil.rmtree(timestamp)
def test_after_test_iter(self):
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
TextDetLocalVisualizer.get_instance(
'visualizer_test',
vis_backends=[dict(type='LocalVisBackend', img_save_dir='')],
save_dir=timestamp)
runner = Mock()
runner.iter = 1
hook = VisualizationHook(enable=False)
hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
self.assertFalse(osp.exists(timestamp))
hook = VisualizationHook(enable=True)
hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
self.assertTrue(osp.exists(timestamp))
shutil.rmtree(timestamp)