mirror of https://github.com/open-mmlab/mmocr.git
139 lines
5.6 KiB
Python
139 lines
5.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
import random
|
|
import tempfile
|
|
from unittest import TestCase, mock
|
|
|
|
import mmcv
|
|
import mmengine
|
|
import numpy as np
|
|
import torch
|
|
from mmengine.structures import InstanceData
|
|
|
|
from mmocr.apis.inferencers import TextDetInferencer
|
|
from mmocr.utils.check_argument import is_type_list
|
|
from mmocr.utils.typing_utils import TextDetDataSample
|
|
|
|
|
|
class TestTextDetinferencer(TestCase):
|
|
|
|
@mock.patch('mmengine.infer.infer._load_checkpoint')
|
|
def setUp(self, mock_load):
|
|
mock_load.side_effect = lambda *x, **y: None
|
|
self.inferencer = TextDetInferencer('DB_r18')
|
|
seed = 1
|
|
random.seed(seed)
|
|
np.random.seed(seed)
|
|
torch.manual_seed(seed)
|
|
|
|
@mock.patch('mmengine.infer.infer._load_checkpoint')
|
|
def test_init(self, mock_load):
|
|
mock_load.side_effect = lambda *x, **y: None
|
|
# init from metafile
|
|
TextDetInferencer('dbnet_resnet18_fpnc_1200e_icdar2015')
|
|
# init from cfg
|
|
TextDetInferencer(
|
|
'configs/textdet/dbnet/'
|
|
'dbnet_resnet18_fpnc_1200e_icdar2015.py',
|
|
'https://download.openmmlab.com/mmocr/textdet/dbnet/'
|
|
'dbnet_resnet18_fpnc_1200e_icdar2015/'
|
|
'dbnet_resnet18_fpnc_1200e_icdar2015_20220825_221614-7c0e94f2.pth')
|
|
|
|
def assert_predictions_equal(self, preds1, preds2):
|
|
for pred1, pred2 in zip(preds1, preds2):
|
|
self.assert_prediction_equal(pred1, pred2)
|
|
|
|
def assert_prediction_equal(self, pred1, pred2):
|
|
self.assertTrue(np.allclose(pred1['polygons'], pred2['polygons'], 0.1))
|
|
self.assertTrue(np.allclose(pred1['scores'], pred2['scores'], 0.1))
|
|
|
|
def test_call(self):
|
|
# single img
|
|
img_path = 'tests/data/det_toy_dataset/imgs/test/img_1.jpg'
|
|
res_path = self.inferencer(img_path, return_vis=True)
|
|
# ndarray
|
|
img = mmcv.imread(img_path)
|
|
res_ndarray = self.inferencer(img, return_vis=True)
|
|
self.assert_predictions_equal(res_path['predictions'],
|
|
res_ndarray['predictions'])
|
|
self.assertTrue(
|
|
np.allclose(res_path['visualization'],
|
|
res_ndarray['visualization']))
|
|
|
|
# multiple images
|
|
img_paths = [
|
|
'tests/data/det_toy_dataset/imgs/test/img_1.jpg',
|
|
'tests/data/det_toy_dataset/imgs/test/img_2.jpg'
|
|
]
|
|
res_path = self.inferencer(img_paths, return_vis=True)
|
|
# list of ndarray
|
|
imgs = [mmcv.imread(p) for p in img_paths]
|
|
res_ndarray = self.inferencer(imgs, return_vis=True)
|
|
self.assert_predictions_equal(res_path['predictions'],
|
|
res_ndarray['predictions'])
|
|
for i in range(len(img_paths)):
|
|
self.assertTrue(
|
|
np.allclose(res_path['visualization'][i],
|
|
res_ndarray['visualization'][i]))
|
|
|
|
# img dir, test different batch sizes
|
|
img_dir = 'tests/data/det_toy_dataset/imgs/test/'
|
|
res_bs1 = self.inferencer(img_dir, batch_size=1, return_vis=True)
|
|
res_bs3 = self.inferencer(img_dir, batch_size=3, return_vis=True)
|
|
self.assert_predictions_equal(res_bs1['predictions'],
|
|
res_bs3['predictions'])
|
|
self.assertTrue(
|
|
np.array_equal(res_bs1['visualization'], res_bs3['visualization']))
|
|
|
|
def test_visualize(self):
|
|
img_paths = [
|
|
'tests/data/det_toy_dataset/imgs/test/img_1.jpg',
|
|
'tests/data/det_toy_dataset/imgs/test/img_2.jpg'
|
|
]
|
|
|
|
# img_out_dir
|
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
self.inferencer(img_paths, out_dir=tmp_dir, save_vis=True)
|
|
for img_dir in ['img_1.jpg', 'img_2.jpg']:
|
|
self.assertTrue(osp.exists(osp.join(tmp_dir, 'vis', img_dir)))
|
|
|
|
def test_postprocess(self):
|
|
# return_datasample
|
|
img_path = 'tests/data/det_toy_dataset/imgs/test/img_1.jpg'
|
|
res = self.inferencer(img_path, return_datasamples=True)
|
|
self.assertTrue(is_type_list(res['predictions'], TextDetDataSample))
|
|
|
|
# dump predictions
|
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
res = self.inferencer(
|
|
img_path, print_result=True, out_dir=tmp_dir, save_pred=True)
|
|
dumped_res = mmengine.load(
|
|
osp.join(tmp_dir, 'preds', 'img_1.json'))
|
|
self.assert_prediction_equal(res['predictions'][0], dumped_res)
|
|
|
|
def test_pred2dict(self):
|
|
data_sample = TextDetDataSample()
|
|
data_sample.pred_instances = InstanceData()
|
|
|
|
data_sample.pred_instances.scores = np.array([0.9])
|
|
data_sample.pred_instances.polygons = [
|
|
np.array([0, 0, 0, 1, 1, 1, 1, 0])
|
|
]
|
|
res = self.inferencer.pred2dict(data_sample)
|
|
self.assertListAlmostEqual(res['polygons'], [[0, 0, 0, 1, 1, 1, 1, 0]])
|
|
self.assertListAlmostEqual(res['scores'], [0.9])
|
|
|
|
data_sample.pred_instances.bboxes = np.array([[0, 0, 1, 1]])
|
|
data_sample.pred_instances.scores = torch.FloatTensor([0.9])
|
|
res = self.inferencer.pred2dict(data_sample)
|
|
self.assertListAlmostEqual(res['polygons'], [[0, 0, 0, 1, 1, 1, 1, 0]])
|
|
self.assertListAlmostEqual(res['bboxes'], [[0, 0, 1, 1]])
|
|
self.assertListAlmostEqual(res['scores'], [0.9])
|
|
|
|
def assertListAlmostEqual(self, list1, list2, places=7):
|
|
for i in range(len(list1)):
|
|
if isinstance(list1[i], list):
|
|
self.assertListAlmostEqual(list1[i], list2[i], places=places)
|
|
else:
|
|
self.assertAlmostEqual(list1[i], list2[i], places=places)
|