diff --git a/mmocr/models/textrecog/recognizer/crnn.py b/mmocr/models/textrecog/recognizer/crnn.py
index 1ff68f7f..cfc98aae 100644
--- a/mmocr/models/textrecog/recognizer/crnn.py
+++ b/mmocr/models/textrecog/recognizer/crnn.py
@@ -1,6 +1,3 @@
-import torch
-import torch.nn.functional as F
-
 from mmdet.models.builder import DETECTORS
 from .encode_decode_recognizer import EncodeDecodeRecognizer

@@ -8,11 +5,3 @@ from .encode_decode_recognizer import EncodeDecodeRecognizer
 @DETECTORS.register_module()
 class CRNNNet(EncodeDecodeRecognizer):
     """CTC-loss based recognizer."""
-
-    def forward_conversion(self, params, img):
-        x = self.extract_feat(img)
-        x = self.encoder(x)
-        outs = self.decoder(x)
-        outs = F.softmax(outs, dim=2)
-        params = torch.pow(params, 1)
-        return outs, params
diff --git a/tests/test_dataset/test_kie_dataset.py b/tests/test_dataset/test_kie_dataset.py
new file mode 100644
index 00000000..da74c6ed
--- /dev/null
+++ b/tests/test_dataset/test_kie_dataset.py
@@ -0,0 +1,114 @@
+import json
+import math
+import os.path as osp
+import tempfile
+
+import pytest
+import torch
+
+from mmocr.datasets.kie_dataset import KIEDataset
+
+
+def _create_dummy_ann_file(ann_file):
+    ann_info1 = {
+        'file_name':
+        'sample1.png',
+        'height':
+        200,
+        'width':
+        200,
+        'annotations': [{
+            'text': 'store',
+            'box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0],
+            'label': 1
+        }, {
+            'text': 'address',
+            'box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0],
+            'label': 1
+        }, {
+            'text': 'price',
+            'box': [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0],
+            'label': 1
+        }, {
+            'text': '1.0',
+            'box': [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0],
+            'label': 1
+        }, {
+            'text': 'google',
+            'box': [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0],
+            'label': 1
+        }]
+    }
+    with open(ann_file, 'w') as fw:
+        for ann_info in [ann_info1]:
+            fw.write(json.dumps(ann_info) + '\n')
+
+    return ann_info1
+
+
+def _create_dummy_dict_file(dict_file):
+    dict_str = '0123'
+    with open(dict_file, 'w') as fw:
+        for char in list(dict_str):
+            fw.write(char + '\n')
+
+    return dict_str
+
+
+def _create_dummy_loader():
+    loader = dict(
+        type='HardDiskLoader',
+        repeat=1,
+        parser=dict(
+            type='LineJsonParser',
+            keys=['file_name', 'height', 'width', 'annotations']))
+    return loader
+
+
+def test_kie_dataset():
+    tmp_dir = tempfile.TemporaryDirectory()
+    # create dummy data
+    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
+    ann_info1 = _create_dummy_ann_file(ann_file)
+
+    dict_file = osp.join(tmp_dir.name, 'fake_dict.txt')
+    _create_dummy_dict_file(dict_file)
+
+    # test initialization
+    loader = _create_dummy_loader()
+    dataset = KIEDataset(ann_file, loader, dict_file, pipeline=[])
+
+    tmp_dir.cleanup()
+
+    # test pre_pipeline
+    img_info = dataset.data_infos[0]
+    results = dict(img_info=img_info)
+    dataset.pre_pipeline(results)
+    assert results['img_prefix'] == dataset.img_prefix
+
+    # test _parse_anno_info
+    annos = ann_info1['annotations']
+    with pytest.raises(AssertionError):
+        dataset._parse_anno_info(annos[0])
+    tmp_annos = [{
+        'text': 'store',
+        'box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]
+    }]
+    with pytest.raises(AssertionError):
+        dataset._parse_anno_info(tmp_annos)
+
+    return_anno = dataset._parse_anno_info(annos)
+    assert 'bboxes' in return_anno
+    assert 'relations' in return_anno
+    assert 'texts' in return_anno
+    assert 'labels' in return_anno
+
+    # test evaluation
+    result = {}
+    result['nodes'] = torch.full((5, 5), 1, dtype=torch.float)
+    result['nodes'][:, 1] = 100.
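+    # Every annotation above carries label 1, and boosting column 1 makes
+    # class 1 the argmax for all five nodes, so only one of the five node
+    # classes scores a non-zero F1; the macro average asserted below is
+    # therefore 1 / 5 = 0.2.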
+    results = [result for _ in range(5)]
+
+    eval_res = dataset.evaluate(results)
+    assert math.isclose(eval_res['macro_f1'], 0.2, abs_tol=1e-4)
diff --git a/tests/test_dataset/test_ocr_transforms.py b/tests/test_dataset/test_ocr_transforms.py
index a568b908..15522b25 100644
--- a/tests/test_dataset/test_ocr_transforms.py
+++ b/tests/test_dataset/test_ocr_transforms.py
@@ -4,6 +4,7 @@ import unittest.mock as mock
 import numpy as np
 import torch
 import torchvision.transforms.functional as TF
+from PIL import Image

 import mmocr.datasets.pipelines.ocr_transforms as transforms

@@ -92,3 +93,48 @@ def test_online_crop(mock_random):
     results = rci(results)

     assert np.allclose(results['img'].shape, [100, 100, 3])
+
+
+def test_fancy_pca():
+    input_tensor = torch.rand(3, 32, 100)
+
+    rci = transforms.FancyPCA()
+
+    results = {'img': input_tensor}
+    results = rci(results)
+
+    assert results['img'].shape == torch.Size([3, 32, 100])
+
+
+@mock.patch('%s.transforms.np.random.uniform' % __name__)
+def test_random_padding(mock_random):
+    kwargs = dict(max_ratio=[0.0, 0.0, 0.0, 0.0], box_type=None)
+
+    mock_random.side_effect = [1, 1, 1, 1]
+
+    src_img = np.ones((32, 100, 3), dtype=np.uint8)
+    results = {'img': src_img, 'img_shape': (32, 100, 3)}
+
+    rci = transforms.RandomPaddingOCR(**kwargs)
+
+    results = rci(results)
+    print(results['img'].shape)
+    assert np.allclose(results['img_shape'], [96, 300, 3])
+
+
+def test_opencv2pil():
+    src_img = np.ones((32, 100, 3), dtype=np.uint8)
+    results = {'img': src_img}
+    rci = transforms.OpencvToPil()
+
+    results = rci(results)
+    assert np.allclose(results['img'].size, (100, 32))
+
+
+def test_pil2opencv():
+    src_img = Image.new('RGB', (100, 32), color=(255, 255, 255))
+    results = {'img': src_img}
+    rci = transforms.PilToOpencv()
+
+    results = rci(results)
+    assert np.allclose(results['img'].shape, (32, 100, 3))
diff --git a/tests/test_models/test_ocr_head.py b/tests/test_models/test_ocr_head.py
index 52761405..7df0f77b 100644
--- a/tests/test_models/test_ocr_head.py
+++ b/tests/test_models/test_ocr_head.py
@@ -4,13 +4,13 @@ import torch
 from mmocr.models.textrecog import SegHead


-def test_cafcn_head():
+def test_seg_head():
     with pytest.raises(AssertionError):
         SegHead(num_classes='100')
     with pytest.raises(AssertionError):
         SegHead(num_classes=-1)

-    cafcn_head = SegHead(num_classes=37)
+    seg_head = SegHead(num_classes=37)
     out_neck = (torch.rand(1, 128, 32, 32), )
-    out_head = cafcn_head(out_neck)
+    out_head = seg_head(out_neck)
     assert out_head.shape == torch.Size([1, 37, 32, 32])
diff --git a/tests/test_models/test_ocr_loss.py b/tests/test_models/test_ocr_loss.py
index 6dad9e2d..51f84fb2 100644
--- a/tests/test_models/test_ocr_loss.py
+++ b/tests/test_models/test_ocr_loss.py
@@ -6,6 +6,14 @@ from mmocr.models.textrecog.losses import CELoss, CTCLoss, SARLoss, TFLoss


 def test_ctc_loss():
+    with pytest.raises(AssertionError):
+        CTCLoss(flatten='flatten')
+    with pytest.raises(AssertionError):
+        CTCLoss(blank=None)
+    with pytest.raises(AssertionError):
+        CTCLoss(reduction=1)
+    with pytest.raises(AssertionError):
+        CTCLoss(zero_infinity='zero')
     # test CTCLoss
     ctc_loss = CTCLoss()
     outputs = torch.zeros(2, 40, 37)
diff --git a/tests/test_models/test_ocr_neck.py b/tests/test_models/test_ocr_neck.py
new file mode 100644
index 00000000..28009311
--- /dev/null
+++ b/tests/test_models/test_ocr_neck.py
@@ -0,0 +1,17 @@
+import torch
+
+from mmocr.models.textrecog.necks import FPNOCR
+
+
+def test_fpn_ocr():
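+    # Simulated 4-level feature pyramid: the spatial size halves at each
+    # level while the channel count grows (128, 256, 512, 512), matching
+    # the in_channels passed to FPNOCR below.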
+    in_s1 = torch.rand(1, 128, 32, 256)
+    in_s2 = torch.rand(1, 256, 16, 128)
+    in_s3 = torch.rand(1, 512, 8, 64)
+    in_s4 = torch.rand(1, 512, 4, 32)
+
+    fpn_ocr = FPNOCR(in_channels=[128, 256, 512, 512], out_channels=256)
+    fpn_ocr.init_weights()
+    fpn_ocr.train()
+
+    out_neck = fpn_ocr((in_s1, in_s2, in_s3, in_s4))
+    assert out_neck[0].shape == torch.Size([1, 256, 32, 256])
diff --git a/tests/test_models/test_recog_config.py b/tests/test_models/test_recog_config.py
new file mode 100644
index 00000000..478743c1
--- /dev/null
+++ b/tests/test_models/test_recog_config.py
@@ -0,0 +1,147 @@
+import copy
+from os.path import dirname, exists, join
+
+import numpy as np
+import pytest
+import torch
+
+
+def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300),
+                    num_items=None):  # yapf: disable
+    """Create a superset of inputs needed to run test or train batches.
+
+    Args:
+        input_shape (tuple): Input batch dimensions.
+
+        num_items (None | list[int]): Specifies the number of boxes
+            for each batch item.
+    """
+
+    (N, C, H, W) = input_shape
+
+    rng = np.random.RandomState(0)
+
+    imgs = rng.rand(*input_shape)
+
+    img_metas = [{
+        'img_shape': (H, W, C),
+        'ori_shape': (H, W, C),
+        'pad_shape': (H, W, C),
+        'filename': '.png',
+        'text': 'hello',
+        'valid_ratio': 1.0,
+    } for _ in range(N)]
+
+    mm_inputs = {
+        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
+        'img_metas': img_metas
+    }
+    return mm_inputs
+
+
+def _demo_gt_kernel_inputs(num_kernels=3, input_shape=(1, 3, 300, 300),
+                           num_items=None):  # yapf: disable
+    """Create a superset of inputs needed to run test or train batches.
+
+    Args:
+        input_shape (tuple): Input batch dimensions.
+
+        num_items (None | list[int]): Specifies the number of boxes
+            for each batch item.
+    """
+    from mmdet.core import BitmapMasks
+
+    (N, C, H, W) = input_shape
+    gt_kernels = []
+
+    for batch_idx in range(N):
+        kernels = []
+        for kernel_idx in range(num_kernels):
+            kernel = np.random.rand(H, W)
+            kernels.append(kernel)
+        gt_kernels.append(BitmapMasks(kernels, H, W))
+
+    return gt_kernels
+
+
+def _get_config_directory():
+    """Find the predefined detector config directory."""
+    try:
+        # Assume we are running in the source mmocr repo
+        repo_dpath = dirname(dirname(dirname(__file__)))
+    except NameError:
+        # For IPython development when this __file__ is not defined
+        import mmocr
+        repo_dpath = dirname(dirname(mmocr.__file__))
+    config_dpath = join(repo_dpath, 'configs')
+    if not exists(config_dpath):
+        raise Exception('Cannot find config path')
+    return config_dpath
+
+
+def _get_config_module(fname):
+    """Load a configuration as a python module."""
+    from mmcv import Config
+    config_dpath = _get_config_directory()
+    config_fpath = join(config_dpath, fname)
+    config_mod = Config.fromfile(config_fpath)
+    return config_mod
+
+
+def _get_detector_cfg(fname):
+    """Grab configs necessary to create a detector.
+
+    These are deep copied to allow for safe modification of parameters without
+    influencing other tests.
+    """
+    config = _get_config_module(fname)
+    model = copy.deepcopy(config.model)
+    return model
+
+
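+# One representative config per recognizer family: attention-based (SAR,
+# RobustScanner), CTC-based (CRNN), transformer-based (NRTR) and a
+# segmentation-based model ('seg'), the only one that needs gt_kernels.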
+ """ + config = _get_config_module(fname) + model = copy.deepcopy(config.model) + return model + + +@pytest.mark.parametrize('cfg_file', [ + 'textrecog/sar/sar_r31_parallel_decoder_academic.py', + 'textrecog/crnn/crnn_academic_dataset.py', + 'textrecog/nrtr/nrtr_r31_academic.py', + 'textrecog/robust_scanner/robustscanner_r31_academic.py', + 'textrecog/seg/seg_r31_1by16_fpnocr_academic.py' +]) +def test_encoder_decoder_pipeline(cfg_file): + model = _get_detector_cfg(cfg_file) + model['pretrained'] = None + + from mmocr.models import build_detector + detector = build_detector(model) + + input_shape = (1, 3, 32, 160) + if 'crnn' in cfg_file: + input_shape = (1, 1, 32, 160) + mm_inputs = _demo_mm_inputs(0, input_shape) + gt_kernels = None + if 'seg' in cfg_file: + gt_kernels = _demo_gt_kernel_inputs(3, input_shape) + + imgs = mm_inputs.pop('imgs') + img_metas = mm_inputs.pop('img_metas') + + # Test forward train + if 'seg' in cfg_file: + losses = detector.forward(imgs, img_metas, gt_kernels=gt_kernels) + else: + losses = detector.forward(imgs, img_metas) + assert isinstance(losses, dict) + + # Test forward test + with torch.no_grad(): + img_list = [g[None, :] for g in imgs] + batch_results = [] + for one_img, one_meta in zip(img_list, img_metas): + result = detector.forward([one_img], [[one_meta]], + return_loss=False) + batch_results.append(result) + + # Test show_result + + results = {'text': 'hello', 'score': 1.0} + img = np.random.rand(5, 5, 3) + detector.show_result(img, results)