Migrate part of old_tests

pull/1178/head
gaotongxiao 2022-07-13 14:16:54 +00:00
parent f107991ac1
commit eb2d5b525a
14 changed files with 44 additions and 302 deletions

View File

@ -23,7 +23,7 @@ from .polygon_utils import (boundary_iou, crop_polygon, is_poly_inside_rect,
poly_union, polys2shapely, rescale_polygon,
rescale_polygons, shapely2poly)
from .setup_env import register_all_modules
from .string_util import StringStrip
from .string_utils import StringStripper
from .typing import (ColorType, ConfigType, DetSampleList, InitConfigType,
KIESampleList, MultiConfig, OptConfigType,
OptDetSampleList, OptInitConfigType, OptKIESampleList,
@ -34,7 +34,7 @@ __all__ = [
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
'is_3dlist', 'is_type_list', 'is_none_or_type', 'equal_len', 'is_2dlist',
'valid_boundary', 'list_to_file', 'list_from_file', 'is_on_same_line',
'stitch_boxes_into_lines', 'StringStrip', 'revert_sync_batchnorm',
'stitch_boxes_into_lines', 'StringStripper', 'revert_sync_batchnorm',
'bezier_to_polygon', 'sort_points', 'dump_ocr_data',
'recog_anno_to_imginfo', 'rescale_polygons', 'rescale_polygon',
'rescale_bboxes', 'bbox2poly', 'crop_polygon', 'is_poly_inside_rect',

View File

@ -4,7 +4,7 @@ import warnings
from typing import Dict, Tuple
from mmocr.registry import TASK_UTILS
from mmocr.utils.string_util import StringStrip
from mmocr.utils.string_utils import StringStripper
@TASK_UTILS.register_module()
@ -33,7 +33,7 @@ class LineStrParser:
self.keys = keys
self.keys_idx = keys_idx
self.separator = separator
self.strip_cls = StringStrip(**kwargs)
self.strip_cls = StringStripper(**kwargs)
def __call__(self, in_str: str) -> Dict:
line_str = self.strip_cls(in_str)

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
class StringStrip:
class StringStripper:
"""Removing the leading and/or the trailing characters based on the string
argument passed.

View File

@ -1,48 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import mmocr.utils as utils
def test_is_3dlist():
assert utils.is_3dlist([])
assert utils.is_3dlist([[]])
assert utils.is_3dlist([[[]]])
assert utils.is_3dlist([[[1]]])
assert not utils.is_3dlist([[1, 2]])
assert not utils.is_3dlist([[np.array([1, 2])]])
def test_is_2dlist():
assert utils.is_2dlist([])
assert utils.is_2dlist([[]])
assert utils.is_2dlist([[1]])
def test_is_type_list():
assert utils.is_type_list([], int)
assert utils.is_type_list([], float)
assert utils.is_type_list([np.array([])], np.ndarray)
assert utils.is_type_list([1], int)
assert utils.is_type_list(['str'], str)
def test_is_none_or_type():
assert utils.is_none_or_type(None, int)
assert utils.is_none_or_type(1.0, float)
assert utils.is_none_or_type(np.ndarray([]), np.ndarray)
assert utils.is_none_or_type(1, int)
assert utils.is_none_or_type('str', str)
def test_valid_boundary():
x = [0, 0, 1, 0, 1, 1, 0, 1]
assert not utils.valid_boundary(x, True)
assert not utils.valid_boundary([0])
assert utils.valid_boundary(x, False)
x = [0, 0, 1, 0, 1, 1, 0, 1, 1]
assert utils.valid_boundary(x, True)

View File

@ -1,22 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.cnn.bricks import ConvModule
from mmocr.utils import revert_sync_batchnorm
def test_revert_sync_batchnorm():
conv_syncbn = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')).to('cpu')
conv_syncbn.train()
x = torch.randn(1, 3, 10, 10)
# Will raise an ValueError saying SyncBN does not run on CPU
with pytest.raises(ValueError):
y = conv_syncbn(x)
conv_bn = revert_sync_batchnorm(conv_syncbn)
y = conv_bn(x)
assert y.shape == (1, 8, 9, 9)
assert conv_bn.training == conv_syncbn.training
conv_syncbn.eval()
conv_bn = revert_sync_batchnorm(conv_syncbn)
assert conv_bn.training == conv_syncbn.training

View File

@ -1,68 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import multiprocessing as mp
import os
import platform
import cv2
from mmcv import Config
from mmocr.utils import setup_multi_processes
def test_setup_multi_processes():
# temp save system setting
sys_start_mehod = mp.get_start_method(allow_none=True)
sys_cv_threads = cv2.getNumThreads()
# pop and temp save system env vars
sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
# test config without setting env
config = dict(data=dict(workers_per_gpu=2))
cfg = Config(config)
setup_multi_processes(cfg)
assert os.getenv('OMP_NUM_THREADS') == '1'
assert os.getenv('MKL_NUM_THREADS') == '1'
# when set to 0, the num threads will be 1
assert cv2.getNumThreads() == 1
if platform.system() != 'Windows':
assert mp.get_start_method() == 'fork'
# test num workers <= 1
os.environ.pop('OMP_NUM_THREADS')
os.environ.pop('MKL_NUM_THREADS')
config = dict(data=dict(workers_per_gpu=0))
cfg = Config(config)
setup_multi_processes(cfg)
assert 'OMP_NUM_THREADS' not in os.environ
assert 'MKL_NUM_THREADS' not in os.environ
# test manually set env var
os.environ['OMP_NUM_THREADS'] = '4'
config = dict(data=dict(workers_per_gpu=2))
cfg = Config(config)
setup_multi_processes(cfg)
assert os.getenv('OMP_NUM_THREADS') == '4'
# test manually set opencv threads and mp start method
config = dict(
data=dict(workers_per_gpu=2),
opencv_num_threads=4,
mp_start_method='spawn')
cfg = Config(config)
setup_multi_processes(cfg)
assert cv2.getNumThreads() == 4
assert mp.get_start_method() == 'spawn'
# revert setting to avoid affecting other programs
if sys_start_mehod:
mp.set_start_method(sys_start_mehod, force=True)
cv2.setNumThreads(sys_cv_threads)
if sys_omp_threads:
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
else:
os.environ.pop('OMP_NUM_THREADS')
if sys_mkl_threads:
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
else:
os.environ.pop('MKL_NUM_THREADS')

View File

@ -1,113 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmocr.models.textdet.postprocessors import (DBPostprocessor,
FCEPostprocessor,
TextSnakePostprocessor)
from mmocr.models.textdet.postprocessors.utils import (comps2boundaries,
poly_nms)
def test_db_boxes_from_bitmaps():
"""Test the boxes_from_bitmaps function in db_decoder."""
pred = np.array([[[0.8, 0.8, 0.8, 0.8, 0], [0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0], [0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0]]])
preds = torch.FloatTensor(pred).requires_grad_(True)
db_decode = DBPostprocessor(text_repr_type='quad', min_text_width=0)
boxes = db_decode(preds)
assert len(boxes) == 1
def test_fcenet_decode():
k = 1
preds = []
preds.append(torch.ones(1, 4, 10, 10))
preds.append(torch.ones(1, 4 * k + 2, 10, 10))
fcenet_decode = FCEPostprocessor(
fourier_degree=k, num_reconstr_points=50, nms_thr=0.01)
boundaries = fcenet_decode(preds=preds, scale=1)
assert isinstance(boundaries, list)
def test_poly_nms():
threshold = 0
polygons = []
polygons.append([10, 10, 10, 30, 30, 30, 30, 10, 0.95])
polygons.append([15, 15, 15, 25, 25, 25, 25, 15, 0.9])
polygons.append([40, 40, 40, 50, 50, 50, 50, 40, 0.85])
polygons.append([5, 5, 5, 15, 15, 15, 15, 5, 0.7])
keep_poly = poly_nms(polygons, threshold)
assert isinstance(keep_poly, list)
def test_comps2boundaries():
# test comps2boundaries
x1 = np.arange(2, 18, 2)
x2 = x1 + 2
y1 = np.ones(8) * 2
y2 = y1 + 2
comp_scores = np.ones(8, dtype=np.float32) * 0.9
text_comps = np.stack([x1, y1, x2, y1, x2, y2, x1, y2,
comp_scores]).transpose()
comp_labels = np.array([1, 1, 1, 1, 1, 3, 5, 5])
shuffle = [3, 2, 5, 7, 6, 0, 4, 1]
boundaries = comps2boundaries(text_comps[shuffle], comp_labels[shuffle])
assert len(boundaries) == 3
# test comps2boundaries with blank inputs
boundaries = comps2boundaries(text_comps[[]], comp_labels[[]])
assert len(boundaries) == 0
def test_textsnake_decode():
maps = torch.zeros((1, 6, 224, 224), dtype=torch.float)
maps[:, 0:2, :, :] = -10.
maps[:, 0, 60:100, 50:170] = 10.
maps[:, 1, 75:85, 60:160] = 10.
maps[:, 2, 75:85, 60:160] = 0.
maps[:, 3, 75:85, 60:160] = 1.
maps[:, 4, 75:85, 60:160] = 10.
# test decoding with text center region of small area
maps[:, 0:2, 150:152, 5:7] = 10.
textsnake_decode = TextSnakePostprocessor()
results = textsnake_decode(torch.squeeze(maps))
assert len(results) == 1
# test decoding with small radius
maps.fill_(0.)
maps[:, 0:2, :, :] = -10.
maps[:, 0, 120:140, 20:40] = 10.
maps[:, 1, 120:140, 20:40] = 10.
maps[:, 2, 120:140, 20:40] = 0.
maps[:, 3, 120:140, 20:40] = 1.
maps[:, 4, 120:140, 20:40] = 0.5
results = textsnake_decode(torch.squeeze(maps))
assert len(results) == 0
def test_db_decode():
pred = torch.zeros((1, 8, 8))
pred[0, 2:7, 2:7] = 0.8
expect_result_quad = [[
1.0, 8.0, 1.0, 1.0, 8.0, 1.0, 8.0, 8.0, 0.800000011920929
]]
expect_result_poly = [[
8, 2, 8, 6, 6, 8, 2, 8, 1, 6, 1, 2, 2, 1, 6, 1, 0.800000011920929
]]
with pytest.raises(AssertionError):
DBPostprocessor(text_repr_type='dummpy')
db_decode = DBPostprocessor(text_repr_type='quad', min_text_width=1)
result_quad = db_decode(preds=pred)
db_decode = DBPostprocessor(text_repr_type='poly', min_text_width=1)
result_poly = db_decode(preds=pred)
assert result_quad == expect_result_quad
assert result_poly == expect_result_poly

View File

@ -36,6 +36,17 @@ class TestDBPostProcessor(unittest.TestCase):
self.assertTrue(
isinstance(results.pred_instances['scores'], torch.FloatTensor))
preds = (torch.FloatTensor([[0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0],
[0.8, 0.8, 0.8, 0.8, 0]]),
torch.rand([1, 10]), torch.rand([1, 10]))
postprocessor = DBPostprocessor(
text_repr_type=text_repr_type, min_text_width=0)
results = postprocessor.get_text_instances(preds, data_sample)
self.assertEqual(len(results.pred_instances['polygons']), 1)
postprocessor = DBPostprocessor(
min_text_score=1, text_repr_type=text_repr_type)
pred_result = (torch.rand(4, 5) * 0.8, torch.rand(4, 5) * 0.8,

View File

@ -27,3 +27,24 @@ class TestDRRGPostProcessor(unittest.TestCase):
self.assertIn('scores', result.pred_instances)
self.assertTrue(
isinstance(result.pred_instances['scores'], torch.FloatTensor))
def test_comps2polys(self):
postprocessor = DRRGPostprocessor()
x1 = np.arange(2, 18, 2)
x2 = x1 + 2
y1 = np.ones(8) * 2
y2 = y1 + 2
comp_scores = np.ones(8, dtype=np.float32) * 0.9
text_comps = np.stack([x1, y1, x2, y1, x2, y2, x1, y2,
comp_scores]).transpose()
comp_labels = np.array([1, 1, 1, 1, 1, 3, 5, 5])
shuffle = [3, 2, 5, 7, 6, 0, 4, 1]
boundaries = postprocessor._comps2polys(text_comps[shuffle],
comp_labels[shuffle])
self.assertEqual(len(boundaries[0]), 3)
boundaries = postprocessor._comps2polys(text_comps[[]],
comp_labels[[]])
self.assertEqual(len(boundaries[0]), 0)

View File

@ -1,39 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import sys
from unittest import TestCase
from mmengine import DefaultScope
from mmocr.utils import register_all_modules
class TestSetupEnv(TestCase):
def test_register_all_modules(self):
from mmocr.registry import DATASETS
# not init default scope
sys.modules.pop('mmocr.datasets', None)
sys.modules.pop('mmocr.datasets.ocr_dataset', None)
DATASETS._module_dict.pop('OCRDataset', None)
self.assertFalse('OCRDataset' in DATASETS.module_dict)
register_all_modules(init_default_scope=False)
self.assertTrue('OCRDataset' in DATASETS.module_dict)
# init default scope
sys.modules.pop('mmocr.datasets')
sys.modules.pop('mmocr.datasets.ocr_dataset')
DATASETS._module_dict.pop('OCRDataset', None)
self.assertFalse('OCRDataset' in DATASETS.module_dict)
register_all_modules(init_default_scope=True)
self.assertTrue('OCRDataset' in DATASETS.module_dict)
self.assertEqual(DefaultScope.get_current_instance().scope_name,
'mmocr')
# init default scope when another scope is init
name = f'test-{datetime.datetime.now()}'
DefaultScope.get_instance(name, scope_name='test')
with self.assertWarnsRegex(
Warning, 'The current default scope "test" is not "mmocr"'):
register_all_modules(init_default_scope=True)

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
from mmocr.utils import StringStrip
from mmocr.utils import StringStripper
def test_string_strip():
@ -23,13 +23,13 @@ def test_string_strip():
for idx3, strip_str in enumerate(strip_str_list):
tmp_args = dict(
strip=strip, strip_pos=strip_pos, strip_str=strip_str)
strip_class = StringStrip(**tmp_args)
strip_class = StringStripper(**tmp_args)
i = idx1 * len(strip_pos_list) * len(
strip_str_list) + idx2 * len(strip_str_list) + idx3
assert strip_class(in_str_list[i]) == out_str_list[i]
with pytest.raises(AssertionError):
StringStrip(strip='strip')
StringStrip(strip_pos='head')
StringStrip(strip_str=['\n', '\t'])
StringStripper(strip='strip')
StringStripper(strip_pos='head')
StringStripper(strip_str=['\n', '\t'])

View File

@ -4,7 +4,7 @@ import json
from typing import List, Tuple
from mmocr.datasets import RecogLMDBDataset
from mmocr.utils import StringStrip, dump_ocr_data, recog_anno_to_imginfo
from mmocr.utils import StringStripper, dump_ocr_data, recog_anno_to_imginfo
def parse_legacy_data(in_path: str,
@ -21,7 +21,7 @@ def parse_legacy_data(in_path: str,
"""
file_paths = []
labels = []
strip_cls = StringStrip()
strip_cls = StringStripper()
if format == 'lmdb':
dataset = RecogLMDBDataset(
in_path,