mirror of https://github.com/open-mmlab/mmocr.git
Migrate part of old_tests
parent
f107991ac1
commit
eb2d5b525a
|
@ -23,7 +23,7 @@ from .polygon_utils import (boundary_iou, crop_polygon, is_poly_inside_rect,
|
|||
poly_union, polys2shapely, rescale_polygon,
|
||||
rescale_polygons, shapely2poly)
|
||||
from .setup_env import register_all_modules
|
||||
from .string_util import StringStrip
|
||||
from .string_utils import StringStripper
|
||||
from .typing import (ColorType, ConfigType, DetSampleList, InitConfigType,
|
||||
KIESampleList, MultiConfig, OptConfigType,
|
||||
OptDetSampleList, OptInitConfigType, OptKIESampleList,
|
||||
|
@ -34,7 +34,7 @@ __all__ = [
|
|||
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
|
||||
'is_3dlist', 'is_type_list', 'is_none_or_type', 'equal_len', 'is_2dlist',
|
||||
'valid_boundary', 'list_to_file', 'list_from_file', 'is_on_same_line',
|
||||
'stitch_boxes_into_lines', 'StringStrip', 'revert_sync_batchnorm',
|
||||
'stitch_boxes_into_lines', 'StringStripper', 'revert_sync_batchnorm',
|
||||
'bezier_to_polygon', 'sort_points', 'dump_ocr_data',
|
||||
'recog_anno_to_imginfo', 'rescale_polygons', 'rescale_polygon',
|
||||
'rescale_bboxes', 'bbox2poly', 'crop_polygon', 'is_poly_inside_rect',
|
||||
|
|
|
@ -4,7 +4,7 @@ import warnings
|
|||
from typing import Dict, Tuple
|
||||
|
||||
from mmocr.registry import TASK_UTILS
|
||||
from mmocr.utils.string_util import StringStrip
|
||||
from mmocr.utils.string_utils import StringStripper
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
|
@ -33,7 +33,7 @@ class LineStrParser:
|
|||
self.keys = keys
|
||||
self.keys_idx = keys_idx
|
||||
self.separator = separator
|
||||
self.strip_cls = StringStrip(**kwargs)
|
||||
self.strip_cls = StringStripper(**kwargs)
|
||||
|
||||
def __call__(self, in_str: str) -> Dict:
|
||||
line_str = self.strip_cls(in_str)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
class StringStrip:
|
||||
class StringStripper:
|
||||
"""Removing the leading and/or the trailing characters based on the string
|
||||
argument passed.
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
|
||||
import mmocr.utils as utils
|
||||
|
||||
|
||||
def test_is_3dlist():
|
||||
|
||||
assert utils.is_3dlist([])
|
||||
assert utils.is_3dlist([[]])
|
||||
assert utils.is_3dlist([[[]]])
|
||||
assert utils.is_3dlist([[[1]]])
|
||||
assert not utils.is_3dlist([[1, 2]])
|
||||
assert not utils.is_3dlist([[np.array([1, 2])]])
|
||||
|
||||
|
||||
def test_is_2dlist():
|
||||
|
||||
assert utils.is_2dlist([])
|
||||
assert utils.is_2dlist([[]])
|
||||
assert utils.is_2dlist([[1]])
|
||||
|
||||
|
||||
def test_is_type_list():
|
||||
assert utils.is_type_list([], int)
|
||||
assert utils.is_type_list([], float)
|
||||
assert utils.is_type_list([np.array([])], np.ndarray)
|
||||
assert utils.is_type_list([1], int)
|
||||
assert utils.is_type_list(['str'], str)
|
||||
|
||||
|
||||
def test_is_none_or_type():
|
||||
|
||||
assert utils.is_none_or_type(None, int)
|
||||
assert utils.is_none_or_type(1.0, float)
|
||||
assert utils.is_none_or_type(np.ndarray([]), np.ndarray)
|
||||
assert utils.is_none_or_type(1, int)
|
||||
assert utils.is_none_or_type('str', str)
|
||||
|
||||
|
||||
def test_valid_boundary():
|
||||
|
||||
x = [0, 0, 1, 0, 1, 1, 0, 1]
|
||||
assert not utils.valid_boundary(x, True)
|
||||
assert not utils.valid_boundary([0])
|
||||
assert utils.valid_boundary(x, False)
|
||||
x = [0, 0, 1, 0, 1, 1, 0, 1, 1]
|
||||
assert utils.valid_boundary(x, True)
|
|
@ -1,22 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.cnn.bricks import ConvModule
|
||||
|
||||
from mmocr.utils import revert_sync_batchnorm
|
||||
|
||||
|
||||
def test_revert_sync_batchnorm():
|
||||
conv_syncbn = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')).to('cpu')
|
||||
conv_syncbn.train()
|
||||
x = torch.randn(1, 3, 10, 10)
|
||||
# Will raise an ValueError saying SyncBN does not run on CPU
|
||||
with pytest.raises(ValueError):
|
||||
y = conv_syncbn(x)
|
||||
conv_bn = revert_sync_batchnorm(conv_syncbn)
|
||||
y = conv_bn(x)
|
||||
assert y.shape == (1, 8, 9, 9)
|
||||
assert conv_bn.training == conv_syncbn.training
|
||||
conv_syncbn.eval()
|
||||
conv_bn = revert_sync_batchnorm(conv_syncbn)
|
||||
assert conv_bn.training == conv_syncbn.training
|
|
@ -1,68 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import platform
|
||||
|
||||
import cv2
|
||||
from mmcv import Config
|
||||
|
||||
from mmocr.utils import setup_multi_processes
|
||||
|
||||
|
||||
def test_setup_multi_processes():
|
||||
# temp save system setting
|
||||
sys_start_mehod = mp.get_start_method(allow_none=True)
|
||||
sys_cv_threads = cv2.getNumThreads()
|
||||
# pop and temp save system env vars
|
||||
sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
|
||||
sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
|
||||
|
||||
# test config without setting env
|
||||
config = dict(data=dict(workers_per_gpu=2))
|
||||
cfg = Config(config)
|
||||
setup_multi_processes(cfg)
|
||||
assert os.getenv('OMP_NUM_THREADS') == '1'
|
||||
assert os.getenv('MKL_NUM_THREADS') == '1'
|
||||
# when set to 0, the num threads will be 1
|
||||
assert cv2.getNumThreads() == 1
|
||||
if platform.system() != 'Windows':
|
||||
assert mp.get_start_method() == 'fork'
|
||||
|
||||
# test num workers <= 1
|
||||
os.environ.pop('OMP_NUM_THREADS')
|
||||
os.environ.pop('MKL_NUM_THREADS')
|
||||
config = dict(data=dict(workers_per_gpu=0))
|
||||
cfg = Config(config)
|
||||
setup_multi_processes(cfg)
|
||||
assert 'OMP_NUM_THREADS' not in os.environ
|
||||
assert 'MKL_NUM_THREADS' not in os.environ
|
||||
|
||||
# test manually set env var
|
||||
os.environ['OMP_NUM_THREADS'] = '4'
|
||||
config = dict(data=dict(workers_per_gpu=2))
|
||||
cfg = Config(config)
|
||||
setup_multi_processes(cfg)
|
||||
assert os.getenv('OMP_NUM_THREADS') == '4'
|
||||
|
||||
# test manually set opencv threads and mp start method
|
||||
config = dict(
|
||||
data=dict(workers_per_gpu=2),
|
||||
opencv_num_threads=4,
|
||||
mp_start_method='spawn')
|
||||
cfg = Config(config)
|
||||
setup_multi_processes(cfg)
|
||||
assert cv2.getNumThreads() == 4
|
||||
assert mp.get_start_method() == 'spawn'
|
||||
|
||||
# revert setting to avoid affecting other programs
|
||||
if sys_start_mehod:
|
||||
mp.set_start_method(sys_start_mehod, force=True)
|
||||
cv2.setNumThreads(sys_cv_threads)
|
||||
if sys_omp_threads:
|
||||
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
|
||||
else:
|
||||
os.environ.pop('OMP_NUM_THREADS')
|
||||
if sys_mkl_threads:
|
||||
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
|
||||
else:
|
||||
os.environ.pop('MKL_NUM_THREADS')
|
|
@ -1,113 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmocr.models.textdet.postprocessors import (DBPostprocessor,
|
||||
FCEPostprocessor,
|
||||
TextSnakePostprocessor)
|
||||
from mmocr.models.textdet.postprocessors.utils import (comps2boundaries,
|
||||
poly_nms)
|
||||
|
||||
|
||||
def test_db_boxes_from_bitmaps():
|
||||
"""Test the boxes_from_bitmaps function in db_decoder."""
|
||||
pred = np.array([[[0.8, 0.8, 0.8, 0.8, 0], [0.8, 0.8, 0.8, 0.8, 0],
|
||||
[0.8, 0.8, 0.8, 0.8, 0], [0.8, 0.8, 0.8, 0.8, 0],
|
||||
[0.8, 0.8, 0.8, 0.8, 0]]])
|
||||
preds = torch.FloatTensor(pred).requires_grad_(True)
|
||||
db_decode = DBPostprocessor(text_repr_type='quad', min_text_width=0)
|
||||
boxes = db_decode(preds)
|
||||
assert len(boxes) == 1
|
||||
|
||||
|
||||
def test_fcenet_decode():
|
||||
|
||||
k = 1
|
||||
preds = []
|
||||
preds.append(torch.ones(1, 4, 10, 10))
|
||||
preds.append(torch.ones(1, 4 * k + 2, 10, 10))
|
||||
fcenet_decode = FCEPostprocessor(
|
||||
fourier_degree=k, num_reconstr_points=50, nms_thr=0.01)
|
||||
boundaries = fcenet_decode(preds=preds, scale=1)
|
||||
|
||||
assert isinstance(boundaries, list)
|
||||
|
||||
|
||||
def test_poly_nms():
|
||||
threshold = 0
|
||||
polygons = []
|
||||
polygons.append([10, 10, 10, 30, 30, 30, 30, 10, 0.95])
|
||||
polygons.append([15, 15, 15, 25, 25, 25, 25, 15, 0.9])
|
||||
polygons.append([40, 40, 40, 50, 50, 50, 50, 40, 0.85])
|
||||
polygons.append([5, 5, 5, 15, 15, 15, 15, 5, 0.7])
|
||||
|
||||
keep_poly = poly_nms(polygons, threshold)
|
||||
assert isinstance(keep_poly, list)
|
||||
|
||||
|
||||
def test_comps2boundaries():
|
||||
|
||||
# test comps2boundaries
|
||||
x1 = np.arange(2, 18, 2)
|
||||
x2 = x1 + 2
|
||||
y1 = np.ones(8) * 2
|
||||
y2 = y1 + 2
|
||||
comp_scores = np.ones(8, dtype=np.float32) * 0.9
|
||||
text_comps = np.stack([x1, y1, x2, y1, x2, y2, x1, y2,
|
||||
comp_scores]).transpose()
|
||||
comp_labels = np.array([1, 1, 1, 1, 1, 3, 5, 5])
|
||||
shuffle = [3, 2, 5, 7, 6, 0, 4, 1]
|
||||
boundaries = comps2boundaries(text_comps[shuffle], comp_labels[shuffle])
|
||||
assert len(boundaries) == 3
|
||||
|
||||
# test comps2boundaries with blank inputs
|
||||
boundaries = comps2boundaries(text_comps[[]], comp_labels[[]])
|
||||
assert len(boundaries) == 0
|
||||
|
||||
|
||||
def test_textsnake_decode():
|
||||
|
||||
maps = torch.zeros((1, 6, 224, 224), dtype=torch.float)
|
||||
maps[:, 0:2, :, :] = -10.
|
||||
maps[:, 0, 60:100, 50:170] = 10.
|
||||
maps[:, 1, 75:85, 60:160] = 10.
|
||||
maps[:, 2, 75:85, 60:160] = 0.
|
||||
maps[:, 3, 75:85, 60:160] = 1.
|
||||
maps[:, 4, 75:85, 60:160] = 10.
|
||||
# test decoding with text center region of small area
|
||||
maps[:, 0:2, 150:152, 5:7] = 10.
|
||||
textsnake_decode = TextSnakePostprocessor()
|
||||
results = textsnake_decode(torch.squeeze(maps))
|
||||
assert len(results) == 1
|
||||
|
||||
# test decoding with small radius
|
||||
maps.fill_(0.)
|
||||
maps[:, 0:2, :, :] = -10.
|
||||
maps[:, 0, 120:140, 20:40] = 10.
|
||||
maps[:, 1, 120:140, 20:40] = 10.
|
||||
maps[:, 2, 120:140, 20:40] = 0.
|
||||
maps[:, 3, 120:140, 20:40] = 1.
|
||||
maps[:, 4, 120:140, 20:40] = 0.5
|
||||
|
||||
results = textsnake_decode(torch.squeeze(maps))
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_db_decode():
|
||||
pred = torch.zeros((1, 8, 8))
|
||||
pred[0, 2:7, 2:7] = 0.8
|
||||
expect_result_quad = [[
|
||||
1.0, 8.0, 1.0, 1.0, 8.0, 1.0, 8.0, 8.0, 0.800000011920929
|
||||
]]
|
||||
expect_result_poly = [[
|
||||
8, 2, 8, 6, 6, 8, 2, 8, 1, 6, 1, 2, 2, 1, 6, 1, 0.800000011920929
|
||||
]]
|
||||
with pytest.raises(AssertionError):
|
||||
DBPostprocessor(text_repr_type='dummpy')
|
||||
db_decode = DBPostprocessor(text_repr_type='quad', min_text_width=1)
|
||||
result_quad = db_decode(preds=pred)
|
||||
db_decode = DBPostprocessor(text_repr_type='poly', min_text_width=1)
|
||||
result_poly = db_decode(preds=pred)
|
||||
assert result_quad == expect_result_quad
|
||||
assert result_poly == expect_result_poly
|
|
@ -36,6 +36,17 @@ class TestDBPostProcessor(unittest.TestCase):
|
|||
self.assertTrue(
|
||||
isinstance(results.pred_instances['scores'], torch.FloatTensor))
|
||||
|
||||
preds = (torch.FloatTensor([[0.8, 0.8, 0.8, 0.8, 0],
|
||||
[0.8, 0.8, 0.8, 0.8, 0],
|
||||
[0.8, 0.8, 0.8, 0.8, 0],
|
||||
[0.8, 0.8, 0.8, 0.8, 0],
|
||||
[0.8, 0.8, 0.8, 0.8, 0]]),
|
||||
torch.rand([1, 10]), torch.rand([1, 10]))
|
||||
postprocessor = DBPostprocessor(
|
||||
text_repr_type=text_repr_type, min_text_width=0)
|
||||
results = postprocessor.get_text_instances(preds, data_sample)
|
||||
self.assertEqual(len(results.pred_instances['polygons']), 1)
|
||||
|
||||
postprocessor = DBPostprocessor(
|
||||
min_text_score=1, text_repr_type=text_repr_type)
|
||||
pred_result = (torch.rand(4, 5) * 0.8, torch.rand(4, 5) * 0.8,
|
||||
|
|
|
@ -27,3 +27,24 @@ class TestDRRGPostProcessor(unittest.TestCase):
|
|||
self.assertIn('scores', result.pred_instances)
|
||||
self.assertTrue(
|
||||
isinstance(result.pred_instances['scores'], torch.FloatTensor))
|
||||
|
||||
def test_comps2polys(self):
|
||||
postprocessor = DRRGPostprocessor()
|
||||
|
||||
x1 = np.arange(2, 18, 2)
|
||||
x2 = x1 + 2
|
||||
y1 = np.ones(8) * 2
|
||||
y2 = y1 + 2
|
||||
comp_scores = np.ones(8, dtype=np.float32) * 0.9
|
||||
text_comps = np.stack([x1, y1, x2, y1, x2, y2, x1, y2,
|
||||
comp_scores]).transpose()
|
||||
comp_labels = np.array([1, 1, 1, 1, 1, 3, 5, 5])
|
||||
shuffle = [3, 2, 5, 7, 6, 0, 4, 1]
|
||||
|
||||
boundaries = postprocessor._comps2polys(text_comps[shuffle],
|
||||
comp_labels[shuffle])
|
||||
self.assertEqual(len(boundaries[0]), 3)
|
||||
|
||||
boundaries = postprocessor._comps2polys(text_comps[[]],
|
||||
comp_labels[[]])
|
||||
self.assertEqual(len(boundaries[0]), 0)
|
||||
|
|
|
@ -1,39 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import datetime
|
||||
import sys
|
||||
from unittest import TestCase
|
||||
|
||||
from mmengine import DefaultScope
|
||||
|
||||
from mmocr.utils import register_all_modules
|
||||
|
||||
|
||||
class TestSetupEnv(TestCase):
|
||||
|
||||
def test_register_all_modules(self):
|
||||
from mmocr.registry import DATASETS
|
||||
|
||||
# not init default scope
|
||||
sys.modules.pop('mmocr.datasets', None)
|
||||
sys.modules.pop('mmocr.datasets.ocr_dataset', None)
|
||||
DATASETS._module_dict.pop('OCRDataset', None)
|
||||
self.assertFalse('OCRDataset' in DATASETS.module_dict)
|
||||
register_all_modules(init_default_scope=False)
|
||||
self.assertTrue('OCRDataset' in DATASETS.module_dict)
|
||||
|
||||
# init default scope
|
||||
sys.modules.pop('mmocr.datasets')
|
||||
sys.modules.pop('mmocr.datasets.ocr_dataset')
|
||||
DATASETS._module_dict.pop('OCRDataset', None)
|
||||
self.assertFalse('OCRDataset' in DATASETS.module_dict)
|
||||
register_all_modules(init_default_scope=True)
|
||||
self.assertTrue('OCRDataset' in DATASETS.module_dict)
|
||||
self.assertEqual(DefaultScope.get_current_instance().scope_name,
|
||||
'mmocr')
|
||||
|
||||
# init default scope when another scope is init
|
||||
name = f'test-{datetime.datetime.now()}'
|
||||
DefaultScope.get_instance(name, scope_name='test')
|
||||
with self.assertWarnsRegex(
|
||||
Warning, 'The current default scope "test" is not "mmocr"'):
|
||||
register_all_modules(init_default_scope=True)
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
|
||||
from mmocr.utils import StringStrip
|
||||
from mmocr.utils import StringStripper
|
||||
|
||||
|
||||
def test_string_strip():
|
||||
|
@ -23,13 +23,13 @@ def test_string_strip():
|
|||
for idx3, strip_str in enumerate(strip_str_list):
|
||||
tmp_args = dict(
|
||||
strip=strip, strip_pos=strip_pos, strip_str=strip_str)
|
||||
strip_class = StringStrip(**tmp_args)
|
||||
strip_class = StringStripper(**tmp_args)
|
||||
i = idx1 * len(strip_pos_list) * len(
|
||||
strip_str_list) + idx2 * len(strip_str_list) + idx3
|
||||
|
||||
assert strip_class(in_str_list[i]) == out_str_list[i]
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
StringStrip(strip='strip')
|
||||
StringStrip(strip_pos='head')
|
||||
StringStrip(strip_str=['\n', '\t'])
|
||||
StringStripper(strip='strip')
|
||||
StringStripper(strip_pos='head')
|
||||
StringStripper(strip_str=['\n', '\t'])
|
|
@ -4,7 +4,7 @@ import json
|
|||
from typing import List, Tuple
|
||||
|
||||
from mmocr.datasets import RecogLMDBDataset
|
||||
from mmocr.utils import StringStrip, dump_ocr_data, recog_anno_to_imginfo
|
||||
from mmocr.utils import StringStripper, dump_ocr_data, recog_anno_to_imginfo
|
||||
|
||||
|
||||
def parse_legacy_data(in_path: str,
|
||||
|
@ -21,7 +21,7 @@ def parse_legacy_data(in_path: str,
|
|||
"""
|
||||
file_paths = []
|
||||
labels = []
|
||||
strip_cls = StringStrip()
|
||||
strip_cls = StringStripper()
|
||||
if format == 'lmdb':
|
||||
dataset = RecogLMDBDataset(
|
||||
in_path,
|
||||
|
|
Loading…
Reference in New Issue