diff --git a/mmocr/utils/__init__.py b/mmocr/utils/__init__.py index f9174e4c..52bd3d6b 100644 --- a/mmocr/utils/__init__.py +++ b/mmocr/utils/__init__.py @@ -23,7 +23,7 @@ from .polygon_utils import (boundary_iou, crop_polygon, is_poly_inside_rect, poly_union, polys2shapely, rescale_polygon, rescale_polygons, shapely2poly) from .setup_env import register_all_modules -from .string_util import StringStrip +from .string_utils import StringStripper from .typing import (ColorType, ConfigType, DetSampleList, InitConfigType, KIESampleList, MultiConfig, OptConfigType, OptDetSampleList, OptInitConfigType, OptKIESampleList, @@ -34,7 +34,7 @@ __all__ = [ 'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env', 'is_3dlist', 'is_type_list', 'is_none_or_type', 'equal_len', 'is_2dlist', 'valid_boundary', 'list_to_file', 'list_from_file', 'is_on_same_line', - 'stitch_boxes_into_lines', 'StringStrip', 'revert_sync_batchnorm', + 'stitch_boxes_into_lines', 'StringStripper', 'revert_sync_batchnorm', 'bezier_to_polygon', 'sort_points', 'dump_ocr_data', 'recog_anno_to_imginfo', 'rescale_polygons', 'rescale_polygon', 'rescale_bboxes', 'bbox2poly', 'crop_polygon', 'is_poly_inside_rect', diff --git a/mmocr/utils/parsers.py b/mmocr/utils/parsers.py index 39d1732f..87cc063d 100644 --- a/mmocr/utils/parsers.py +++ b/mmocr/utils/parsers.py @@ -4,7 +4,7 @@ import warnings from typing import Dict, Tuple from mmocr.registry import TASK_UTILS -from mmocr.utils.string_util import StringStrip +from mmocr.utils.string_utils import StringStripper @TASK_UTILS.register_module() @@ -33,7 +33,7 @@ class LineStrParser: self.keys = keys self.keys_idx = keys_idx self.separator = separator - self.strip_cls = StringStrip(**kwargs) + self.strip_cls = StringStripper(**kwargs) def __call__(self, in_str: str) -> Dict: line_str = self.strip_cls(in_str) diff --git a/mmocr/utils/string_util.py b/mmocr/utils/string_utils.py similarity index 98% rename from mmocr/utils/string_util.py rename to mmocr/utils/string_utils.py index 5a8946ee..4c597408 100644 --- a/mmocr/utils/string_util.py +++ b/mmocr/utils/string_utils.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -class StringStrip: +class StringStripper: """Removing the leading and/or the trailing characters based on the string argument passed. diff --git a/old_tests/test_utils/test_check_argument.py b/old_tests/test_utils/test_check_argument.py deleted file mode 100644 index bd639e37..00000000 --- a/old_tests/test_utils/test_check_argument.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import numpy as np - -import mmocr.utils as utils - - -def test_is_3dlist(): - - assert utils.is_3dlist([]) - assert utils.is_3dlist([[]]) - assert utils.is_3dlist([[[]]]) - assert utils.is_3dlist([[[1]]]) - assert not utils.is_3dlist([[1, 2]]) - assert not utils.is_3dlist([[np.array([1, 2])]]) - - -def test_is_2dlist(): - - assert utils.is_2dlist([]) - assert utils.is_2dlist([[]]) - assert utils.is_2dlist([[1]]) - - -def test_is_type_list(): - assert utils.is_type_list([], int) - assert utils.is_type_list([], float) - assert utils.is_type_list([np.array([])], np.ndarray) - assert utils.is_type_list([1], int) - assert utils.is_type_list(['str'], str) - - -def test_is_none_or_type(): - - assert utils.is_none_or_type(None, int) - assert utils.is_none_or_type(1.0, float) - assert utils.is_none_or_type(np.ndarray([]), np.ndarray) - assert utils.is_none_or_type(1, int) - assert utils.is_none_or_type('str', str) - - -def test_valid_boundary(): - - x = [0, 0, 1, 0, 1, 1, 0, 1] - assert not utils.valid_boundary(x, True) - assert not utils.valid_boundary([0]) - assert utils.valid_boundary(x, False) - x = [0, 0, 1, 0, 1, 1, 0, 1, 1] - assert utils.valid_boundary(x, True) diff --git a/old_tests/test_utils/test_model.py b/old_tests/test_utils/test_model.py deleted file mode 100644 index d86d821a..00000000 --- a/old_tests/test_utils/test_model.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pytest -import torch -from mmcv.cnn.bricks import ConvModule - -from mmocr.utils import revert_sync_batchnorm - - -def test_revert_sync_batchnorm(): - conv_syncbn = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')).to('cpu') - conv_syncbn.train() - x = torch.randn(1, 3, 10, 10) - # Will raise an ValueError saying SyncBN does not run on CPU - with pytest.raises(ValueError): - y = conv_syncbn(x) - conv_bn = revert_sync_batchnorm(conv_syncbn) - y = conv_bn(x) - assert y.shape == (1, 8, 9, 9) - assert conv_bn.training == conv_syncbn.training - conv_syncbn.eval() - conv_bn = revert_sync_batchnorm(conv_syncbn) - assert conv_bn.training == conv_syncbn.training diff --git a/old_tests/test_utils/test_setup_env.py b/old_tests/test_utils/test_setup_env.py deleted file mode 100644 index b65b9647..00000000 --- a/old_tests/test_utils/test_setup_env.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import multiprocessing as mp -import os -import platform - -import cv2 -from mmcv import Config - -from mmocr.utils import setup_multi_processes - - -def test_setup_multi_processes(): - # temp save system setting - sys_start_mehod = mp.get_start_method(allow_none=True) - sys_cv_threads = cv2.getNumThreads() - # pop and temp save system env vars - sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None) - sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None) - - # test config without setting env - config = dict(data=dict(workers_per_gpu=2)) - cfg = Config(config) - setup_multi_processes(cfg) - assert os.getenv('OMP_NUM_THREADS') == '1' - assert os.getenv('MKL_NUM_THREADS') == '1' - # when set to 0, the num threads will be 1 - assert cv2.getNumThreads() == 1 - if platform.system() != 'Windows': - assert mp.get_start_method() == 'fork' - - # test num workers <= 1 - os.environ.pop('OMP_NUM_THREADS') - os.environ.pop('MKL_NUM_THREADS') - config = dict(data=dict(workers_per_gpu=0)) - cfg = Config(config) - setup_multi_processes(cfg) - assert 'OMP_NUM_THREADS' not in os.environ - assert 'MKL_NUM_THREADS' not in os.environ - - # test manually set env var - os.environ['OMP_NUM_THREADS'] = '4' - config = dict(data=dict(workers_per_gpu=2)) - cfg = Config(config) - setup_multi_processes(cfg) - assert os.getenv('OMP_NUM_THREADS') == '4' - - # test manually set opencv threads and mp start method - config = dict( - data=dict(workers_per_gpu=2), - opencv_num_threads=4, - mp_start_method='spawn') - cfg = Config(config) - setup_multi_processes(cfg) - assert cv2.getNumThreads() == 4 - assert mp.get_start_method() == 'spawn' - - # revert setting to avoid affecting other programs - if sys_start_mehod: - mp.set_start_method(sys_start_mehod, force=True) - cv2.setNumThreads(sys_cv_threads) - if sys_omp_threads: - os.environ['OMP_NUM_THREADS'] = sys_omp_threads - else: - os.environ.pop('OMP_NUM_THREADS') - if sys_mkl_threads: - os.environ['MKL_NUM_THREADS'] = sys_mkl_threads - else: - os.environ.pop('MKL_NUM_THREADS') diff --git a/old_tests/test_utils/test_wrapper.py b/old_tests/test_utils/test_wrapper.py deleted file mode 100644 index deb32afe..00000000 --- a/old_tests/test_utils/test_wrapper.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import numpy as np -import pytest -import torch - -from mmocr.models.textdet.postprocessors import (DBPostprocessor, - FCEPostprocessor, - TextSnakePostprocessor) -from mmocr.models.textdet.postprocessors.utils import (comps2boundaries, - poly_nms) - - -def test_db_boxes_from_bitmaps(): - """Test the boxes_from_bitmaps function in db_decoder.""" - pred = np.array([[[0.8, 0.8, 0.8, 0.8, 0], [0.8, 0.8, 0.8, 0.8, 0], - [0.8, 0.8, 0.8, 0.8, 0], [0.8, 0.8, 0.8, 0.8, 0], - [0.8, 0.8, 0.8, 0.8, 0]]]) - preds = torch.FloatTensor(pred).requires_grad_(True) - db_decode = DBPostprocessor(text_repr_type='quad', min_text_width=0) - boxes = db_decode(preds) - assert len(boxes) == 1 - - -def test_fcenet_decode(): - - k = 1 - preds = [] - preds.append(torch.ones(1, 4, 10, 10)) - preds.append(torch.ones(1, 4 * k + 2, 10, 10)) - fcenet_decode = FCEPostprocessor( - fourier_degree=k, num_reconstr_points=50, nms_thr=0.01) - boundaries = fcenet_decode(preds=preds, scale=1) - - assert isinstance(boundaries, list) - - -def test_poly_nms(): - threshold = 0 - polygons = [] - polygons.append([10, 10, 10, 30, 30, 30, 30, 10, 0.95]) - polygons.append([15, 15, 15, 25, 25, 25, 25, 15, 0.9]) - polygons.append([40, 40, 40, 50, 50, 50, 50, 40, 0.85]) - polygons.append([5, 5, 5, 15, 15, 15, 15, 5, 0.7]) - - keep_poly = poly_nms(polygons, threshold) - assert isinstance(keep_poly, list) - - -def test_comps2boundaries(): - - # test comps2boundaries - x1 = np.arange(2, 18, 2) - x2 = x1 + 2 - y1 = np.ones(8) * 2 - y2 = y1 + 2 - comp_scores = np.ones(8, dtype=np.float32) * 0.9 - text_comps = np.stack([x1, y1, x2, y1, x2, y2, x1, y2, - comp_scores]).transpose() - comp_labels = np.array([1, 1, 1, 1, 1, 3, 5, 5]) - shuffle = [3, 2, 5, 7, 6, 0, 4, 1] - boundaries = comps2boundaries(text_comps[shuffle], comp_labels[shuffle]) - assert len(boundaries) == 3 - - # test comps2boundaries with blank inputs - boundaries = comps2boundaries(text_comps[[]], comp_labels[[]]) - assert len(boundaries) == 0 - - -def test_textsnake_decode(): - - maps = torch.zeros((1, 6, 224, 224), dtype=torch.float) - maps[:, 0:2, :, :] = -10. - maps[:, 0, 60:100, 50:170] = 10. - maps[:, 1, 75:85, 60:160] = 10. - maps[:, 2, 75:85, 60:160] = 0. - maps[:, 3, 75:85, 60:160] = 1. - maps[:, 4, 75:85, 60:160] = 10. - # test decoding with text center region of small area - maps[:, 0:2, 150:152, 5:7] = 10. - textsnake_decode = TextSnakePostprocessor() - results = textsnake_decode(torch.squeeze(maps)) - assert len(results) == 1 - - # test decoding with small radius - maps.fill_(0.) - maps[:, 0:2, :, :] = -10. - maps[:, 0, 120:140, 20:40] = 10. - maps[:, 1, 120:140, 20:40] = 10. - maps[:, 2, 120:140, 20:40] = 0. - maps[:, 3, 120:140, 20:40] = 1. - maps[:, 4, 120:140, 20:40] = 0.5 - - results = textsnake_decode(torch.squeeze(maps)) - assert len(results) == 0 - - -def test_db_decode(): - pred = torch.zeros((1, 8, 8)) - pred[0, 2:7, 2:7] = 0.8 - expect_result_quad = [[ - 1.0, 8.0, 1.0, 1.0, 8.0, 1.0, 8.0, 8.0, 0.800000011920929 - ]] - expect_result_poly = [[ - 8, 2, 8, 6, 6, 8, 2, 8, 1, 6, 1, 2, 2, 1, 6, 1, 0.800000011920929 - ]] - with pytest.raises(AssertionError): - DBPostprocessor(text_repr_type='dummpy') - db_decode = DBPostprocessor(text_repr_type='quad', min_text_width=1) - result_quad = db_decode(preds=pred) - db_decode = DBPostprocessor(text_repr_type='poly', min_text_width=1) - result_poly = db_decode(preds=pred) - assert result_quad == expect_result_quad - assert result_poly == expect_result_poly diff --git a/old_tests/test_utils/test_version_utils.py b/tests/test_init.py similarity index 100% rename from old_tests/test_utils/test_version_utils.py rename to tests/test_init.py diff --git a/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py b/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py index 49a1e6b9..0425d2bd 100644 --- a/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py +++ b/tests/test_models/test_textdet/test_postprocessors/test_db_postprocessor.py @@ -36,6 +36,17 @@ class TestDBPostProcessor(unittest.TestCase): self.assertTrue( isinstance(results.pred_instances['scores'], torch.FloatTensor)) + preds = (torch.FloatTensor([[0.8, 0.8, 0.8, 0.8, 0], + [0.8, 0.8, 0.8, 0.8, 0], + [0.8, 0.8, 0.8, 0.8, 0], + [0.8, 0.8, 0.8, 0.8, 0], + [0.8, 0.8, 0.8, 0.8, 0]]), + torch.rand([1, 10]), torch.rand([1, 10])) + postprocessor = DBPostprocessor( + text_repr_type=text_repr_type, min_text_width=0) + results = postprocessor.get_text_instances(preds, data_sample) + self.assertEqual(len(results.pred_instances['polygons']), 1) + postprocessor = DBPostprocessor( min_text_score=1, text_repr_type=text_repr_type) pred_result = (torch.rand(4, 5) * 0.8, torch.rand(4, 5) * 0.8, diff --git a/tests/test_models/test_textdet/test_postprocessors/test_drrg_postprocessor.py b/tests/test_models/test_textdet/test_postprocessors/test_drrg_postprocessor.py index 7ec32bbf..e968b5e4 100644 --- a/tests/test_models/test_textdet/test_postprocessors/test_drrg_postprocessor.py +++ b/tests/test_models/test_textdet/test_postprocessors/test_drrg_postprocessor.py @@ -27,3 +27,24 @@ class TestDRRGPostProcessor(unittest.TestCase): self.assertIn('scores', result.pred_instances) self.assertTrue( isinstance(result.pred_instances['scores'], torch.FloatTensor)) + + def test_comps2polys(self): + postprocessor = DRRGPostprocessor() + + x1 = np.arange(2, 18, 2) + x2 = x1 + 2 + y1 = np.ones(8) * 2 + y2 = y1 + 2 + comp_scores = np.ones(8, dtype=np.float32) * 0.9 + text_comps = np.stack([x1, y1, x2, y1, x2, y2, x1, y2, + comp_scores]).transpose() + comp_labels = np.array([1, 1, 1, 1, 1, 3, 5, 5]) + shuffle = [3, 2, 5, 7, 6, 0, 4, 1] + + boundaries = postprocessor._comps2polys(text_comps[shuffle], + comp_labels[shuffle]) + self.assertEqual(len(boundaries[0]), 3) + + boundaries = postprocessor._comps2polys(text_comps[[]], + comp_labels[[]]) + self.assertEqual(len(boundaries[0]), 0) diff --git a/old_tests/test_utils/test_textio.py b/tests/test_utils/test_fileio.py similarity index 100% rename from old_tests/test_utils/test_textio.py rename to tests/test_utils/test_fileio.py diff --git a/tests/test_utils/test_setup_env.py b/tests/test_utils/test_setup_env.py deleted file mode 100644 index c2dd9811..00000000 --- a/tests/test_utils/test_setup_env.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import datetime -import sys -from unittest import TestCase - -from mmengine import DefaultScope - -from mmocr.utils import register_all_modules - - -class TestSetupEnv(TestCase): - - def test_register_all_modules(self): - from mmocr.registry import DATASETS - - # not init default scope - sys.modules.pop('mmocr.datasets', None) - sys.modules.pop('mmocr.datasets.ocr_dataset', None) - DATASETS._module_dict.pop('OCRDataset', None) - self.assertFalse('OCRDataset' in DATASETS.module_dict) - register_all_modules(init_default_scope=False) - self.assertTrue('OCRDataset' in DATASETS.module_dict) - - # init default scope - sys.modules.pop('mmocr.datasets') - sys.modules.pop('mmocr.datasets.ocr_dataset') - DATASETS._module_dict.pop('OCRDataset', None) - self.assertFalse('OCRDataset' in DATASETS.module_dict) - register_all_modules(init_default_scope=True) - self.assertTrue('OCRDataset' in DATASETS.module_dict) - self.assertEqual(DefaultScope.get_current_instance().scope_name, - 'mmocr') - - # init default scope when another scope is init - name = f'test-{datetime.datetime.now()}' - DefaultScope.get_instance(name, scope_name='test') - with self.assertWarnsRegex( - Warning, 'The current default scope "test" is not "mmocr"'): - register_all_modules(init_default_scope=True) diff --git a/old_tests/test_utils/test_string_util.py b/tests/test_utils/test_string_utils.py similarity index 82% rename from old_tests/test_utils/test_string_util.py rename to tests/test_utils/test_string_utils.py index c0eb4678..e4c3d720 100644 --- a/old_tests/test_utils/test_string_util.py +++ b/tests/test_utils/test_string_utils.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest -from mmocr.utils import StringStrip +from mmocr.utils import StringStripper def test_string_strip(): @@ -23,13 +23,13 @@ def test_string_strip(): for idx3, strip_str in enumerate(strip_str_list): tmp_args = dict( strip=strip, strip_pos=strip_pos, strip_str=strip_str) - strip_class = StringStrip(**tmp_args) + strip_class = StringStripper(**tmp_args) i = idx1 * len(strip_pos_list) * len( strip_str_list) + idx2 * len(strip_str_list) + idx3 assert strip_class(in_str_list[i]) == out_str_list[i] with pytest.raises(AssertionError): - StringStrip(strip='strip') - StringStrip(strip_pos='head') - StringStrip(strip_str=['\n', '\t']) + StringStripper(strip='strip') + StringStripper(strip_pos='head') + StringStripper(strip_str=['\n', '\t']) diff --git a/tools/data/textrecog/data_migrator.py b/tools/data/textrecog/data_migrator.py index 3f1fafd1..9fb0f205 100644 --- a/tools/data/textrecog/data_migrator.py +++ b/tools/data/textrecog/data_migrator.py @@ -4,7 +4,7 @@ import json from typing import List, Tuple from mmocr.datasets import RecogLMDBDataset -from mmocr.utils import StringStrip, dump_ocr_data, recog_anno_to_imginfo +from mmocr.utils import StringStripper, dump_ocr_data, recog_anno_to_imginfo def parse_legacy_data(in_path: str, @@ -21,7 +21,7 @@ def parse_legacy_data(in_path: str, """ file_paths = [] labels = [] - strip_cls = StringStrip() + strip_cls = StringStripper() if format == 'lmdb': dataset = RecogLMDBDataset( in_path,