mirror of https://github.com/open-mmlab/mmocr.git
80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
import shutil
|
|
import time
|
|
from unittest import TestCase
|
|
from unittest.mock import Mock
|
|
|
|
import torch
|
|
from mmengine.structures import InstanceData
|
|
|
|
from mmocr.engine.hooks import VisualizationHook
|
|
from mmocr.structures import TextDetDataSample
|
|
from mmocr.visualization import TextDetLocalVisualizer
|
|
|
|
|
|
def _rand_bboxes(num_boxes, h, w):
|
|
cx, cy, bw, bh = torch.rand(num_boxes, 4).T
|
|
|
|
tl_x = ((cx * w) - (w * bw / 2)).clamp(0, w).unsqueeze(0)
|
|
tl_y = ((cy * h) - (h * bh / 2)).clamp(0, h).unsqueeze(0)
|
|
br_x = ((cx * w) + (w * bw / 2)).clamp(0, w).unsqueeze(0)
|
|
br_y = ((cy * h) + (h * bh / 2)).clamp(0, h).unsqueeze(0)
|
|
|
|
bboxes = torch.cat([tl_x, tl_y, br_x, br_y], dim=0).T
|
|
return bboxes
|
|
|
|
|
|
class TestVisualizationHook(TestCase):
|
|
|
|
def setUp(self) -> None:
|
|
|
|
data_sample = TextDetDataSample()
|
|
data_sample.set_metainfo({
|
|
'img_path':
|
|
osp.join(
|
|
osp.dirname(__file__),
|
|
'../../data/det_toy_dataset/imgs/test/img_1.jpg')
|
|
})
|
|
|
|
pred_instances = InstanceData()
|
|
pred_instances.bboxes = _rand_bboxes(5, 10, 12)
|
|
pred_instances.labels = torch.randint(0, 2, (5, ))
|
|
pred_instances.scores = torch.rand((5, ))
|
|
|
|
data_sample.pred_instances = pred_instances
|
|
self.outputs = [data_sample] * 2
|
|
self.data_batch = None
|
|
|
|
def test_after_val_iter(self):
|
|
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
|
TextDetLocalVisualizer.get_instance(
|
|
'visualizer_val',
|
|
vis_backends=[dict(type='LocalVisBackend', img_save_dir='')],
|
|
save_dir=timestamp)
|
|
runner = Mock()
|
|
runner.iter = 1
|
|
hook = VisualizationHook(enable=True, interval=1)
|
|
self.assertFalse(osp.exists(timestamp))
|
|
hook.after_val_iter(runner, 1, self.data_batch, self.outputs)
|
|
self.assertTrue(osp.exists(timestamp))
|
|
shutil.rmtree(timestamp)
|
|
|
|
def test_after_test_iter(self):
|
|
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
|
TextDetLocalVisualizer.get_instance(
|
|
'visualizer_test',
|
|
vis_backends=[dict(type='LocalVisBackend', img_save_dir='')],
|
|
save_dir=timestamp)
|
|
runner = Mock()
|
|
runner.iter = 1
|
|
|
|
hook = VisualizationHook(enable=False)
|
|
hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
|
|
self.assertFalse(osp.exists(timestamp))
|
|
|
|
hook = VisualizationHook(enable=True)
|
|
hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
|
|
self.assertTrue(osp.exists(timestamp))
|
|
shutil.rmtree(timestamp)
|