mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
parent
c8793ac141
commit
fbb5c8cda1
@ -1,6 +1,3 @@
|
|||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from mmdet.models.builder import DETECTORS
|
from mmdet.models.builder import DETECTORS
|
||||||
from .encode_decode_recognizer import EncodeDecodeRecognizer
|
from .encode_decode_recognizer import EncodeDecodeRecognizer
|
||||||
|
|
||||||
@ -8,11 +5,3 @@ from .encode_decode_recognizer import EncodeDecodeRecognizer
|
|||||||
@DETECTORS.register_module()
|
@DETECTORS.register_module()
|
||||||
class CRNNNet(EncodeDecodeRecognizer):
|
class CRNNNet(EncodeDecodeRecognizer):
|
||||||
"""CTC-loss based recognizer."""
|
"""CTC-loss based recognizer."""
|
||||||
|
|
||||||
def forward_conversion(self, params, img):
|
|
||||||
x = self.extract_feat(img)
|
|
||||||
x = self.encoder(x)
|
|
||||||
outs = self.decoder(x)
|
|
||||||
outs = F.softmax(outs, dim=2)
|
|
||||||
params = torch.pow(params, 1)
|
|
||||||
return outs, params
|
|
||||||
|
114
tests/test_dataset/test_kie_dataset.py
Normal file
114
tests/test_dataset/test_kie_dataset.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os.path as osp
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmocr.datasets.kie_dataset import KIEDataset
|
||||||
|
|
||||||
|
|
||||||
|
def _create_dummy_ann_file(ann_file):
    """Write a one-record, line-delimited JSON annotation file.

    Args:
        ann_file (str): Path of the file to create; it is opened in ``'w'``
            mode, so existing content is replaced.

    Returns:
        dict: The annotation record that was serialized into ``ann_file``.
    """
    # (text, box) pairs; every annotation below is given 'label' 1.
    text_boxes = [
        ('store', [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]),
        ('address', [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0]),
        ('price', [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0]),
        ('1.0', [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0]),
        ('google', [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0]),
    ]
    ann_info1 = {
        'file_name': 'sample1.png',
        'height': 200,
        'width': 200,
        'annotations': [
            {'text': text, 'box': box, 'label': 1}
            for text, box in text_boxes
        ],
    }

    # One JSON document per line.
    serialized = [json.dumps(record) + '\n' for record in [ann_info1]]
    with open(ann_file, 'w') as fw:
        fw.writelines(serialized)

    return ann_info1
|
||||||
|
|
||||||
|
|
||||||
|
def _create_dummy_dict_file(dict_file):
    """Write the characters of ``'0123'`` to a file, one per line.

    Args:
        dict_file (str): Path of the file to create; it is opened in ``'w'``
            mode, so existing content is replaced.

    Returns:
        str: The string whose characters were written.
    """
    dict_str = '0123'
    payload = ''.join(char + '\n' for char in dict_str)
    with open(dict_file, 'w') as fw:
        fw.write(payload)

    return dict_str
|
||||||
|
|
||||||
|
|
||||||
|
def _create_dummy_loader():
    """Return a fresh loader config dict.

    The config names a ``HardDiskLoader`` with ``repeat=1`` and a
    ``LineJsonParser`` that is given the four annotation keys.

    Returns:
        dict: A newly built configuration on every call.
    """
    parser_cfg = {
        'type': 'LineJsonParser',
        'keys': ['file_name', 'height', 'width', 'annotations'],
    }
    return {'type': 'HardDiskLoader', 'repeat': 1, 'parser': parser_cfg}
|
||||||
|
|
||||||
|
|
||||||
|
def test_kie_dataset():
    """Exercise KIEDataset construction, pre_pipeline, parsing and evaluate.

    The annotation and dictionary fixtures live in a temporary directory that
    is only needed while the dataset is being constructed.
    """
    # The context manager removes the directory even if fixture creation or
    # the dataset constructor raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy data
        ann_file = osp.join(tmp_dir, 'fake_data.txt')
        ann_info1 = _create_dummy_ann_file(ann_file)

        dict_file = osp.join(tmp_dir, 'fake_dict.txt')
        _create_dummy_dict_file(dict_file)

        # test initialization
        loader = _create_dummy_loader()
        dataset = KIEDataset(ann_file, loader, dict_file, pipeline=[])

    # test pre_pipeline
    img_info = dataset.data_infos[0]
    results = dict(img_info=img_info)
    dataset.pre_pipeline(results)
    assert results['img_prefix'] == dataset.img_prefix

    # test _parse_anno_info
    annos = ann_info1['annotations']
    # A single annotation dict (not a list of them) is expected to be
    # rejected.
    with pytest.raises(AssertionError):
        dataset._parse_anno_info(annos[0])
    # An annotation without the 'label' key is expected to be rejected.
    tmp_annos = [{
        'text': 'store',
        'box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]
    }]
    with pytest.raises(AssertionError):
        dataset._parse_anno_info(tmp_annos)

    return_anno = dataset._parse_anno_info(annos)
    assert 'bboxes' in return_anno
    assert 'relations' in return_anno
    assert 'texts' in return_anno
    assert 'labels' in return_anno

    # test evaluation
    result = {}
    result['nodes'] = torch.full((5, 5), 1, dtype=torch.float)
    # Column 1 gets a much larger value than the other columns.
    result['nodes'][:, 1] = 100.
    # The same result object is reused for every entry.
    results = [result for _ in range(5)]

    eval_res = dataset.evaluate(results)
    assert math.isclose(eval_res['macro_f1'], 0.2, abs_tol=1e-4)
|
@ -4,6 +4,7 @@ import unittest.mock as mock
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchvision.transforms.functional as TF
|
import torchvision.transforms.functional as TF
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
import mmocr.datasets.pipelines.ocr_transforms as transforms
|
import mmocr.datasets.pipelines.ocr_transforms as transforms
|
||||||
|
|
||||||
@ -92,3 +93,48 @@ def test_online_crop(mock_random):
|
|||||||
|
|
||||||
results = rci(results)
|
results = rci(results)
|
||||||
assert np.allclose(results['img'].shape, [100, 100, 3])
|
assert np.allclose(results['img'].shape, [100, 100, 3])
|
||||||
|
|
||||||
|
|
||||||
|
def test_fancy_pca():
    """FancyPCA must return an 'img' with the same shape it was given."""
    image = torch.rand(3, 32, 100)
    fancy_pca = transforms.FancyPCA()

    output = fancy_pca({'img': image})

    expected_shape = torch.Size([3, 32, 100])
    assert output['img'].shape == expected_shape
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch('%s.transforms.np.random.uniform' % __name__)
def test_random_padding(mock_random):
    """Check the 'img_shape' reported by RandomPaddingOCR.

    ``np.random.uniform``, as seen through the ``transforms`` module, is
    patched so that its first four calls each return 1.

    Args:
        mock_random (mock.MagicMock): Replacement for ``np.random.uniform``
            injected by ``mock.patch``.
    """
    kwargs = dict(max_ratio=[0.0, 0.0, 0.0, 0.0], box_type=None)

    mock_random.side_effect = [1, 1, 1, 1]

    src_img = np.ones((32, 100, 3), dtype=np.uint8)
    results = {'img': src_img, 'img_shape': (32, 100, 3)}

    rci = transforms.RandomPaddingOCR(**kwargs)

    results = rci(results)
    assert np.allclose(results['img_shape'], [96, 300, 3])
|
||||||
|
|
||||||
|
|
||||||
|
def test_opencv2pil():
    """OpencvToPil on a (32, 100, 3) uint8 array yields ``size`` (100, 32)."""
    to_pil = transforms.OpencvToPil()
    array_img = np.ones((32, 100, 3), dtype=np.uint8)

    converted = to_pil({'img': array_img})

    assert np.allclose(converted['img'].size, (100, 32))
|
||||||
|
|
||||||
|
|
||||||
|
def test_pil2opencv():
    """PilToOpencv on a 100x32 RGB image yields ``shape`` (32, 100, 3)."""
    to_opencv = transforms.PilToOpencv()
    pil_img = Image.new('RGB', (100, 32), color=(255, 255, 255))

    converted = to_opencv({'img': pil_img})

    assert np.allclose(converted['img'].shape, (32, 100, 3))
|
||||||
|
@ -4,13 +4,13 @@ import torch
|
|||||||
from mmocr.models.textrecog import SegHead
|
from mmocr.models.textrecog import SegHead
|
||||||
|
|
||||||
|
|
||||||
def test_cafcn_head():
|
def test_seg_head():
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
SegHead(num_classes='100')
|
SegHead(num_classes='100')
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
SegHead(num_classes=-1)
|
SegHead(num_classes=-1)
|
||||||
|
|
||||||
cafcn_head = SegHead(num_classes=37)
|
seg_head = SegHead(num_classes=37)
|
||||||
out_neck = (torch.rand(1, 128, 32, 32), )
|
out_neck = (torch.rand(1, 128, 32, 32), )
|
||||||
out_head = cafcn_head(out_neck)
|
out_head = seg_head(out_neck)
|
||||||
assert out_head.shape == torch.Size([1, 37, 32, 32])
|
assert out_head.shape == torch.Size([1, 37, 32, 32])
|
||||||
|
@ -6,6 +6,14 @@ from mmocr.models.textrecog.losses import CELoss, CTCLoss, SARLoss, TFLoss
|
|||||||
|
|
||||||
|
|
||||||
def test_ctc_loss():
|
def test_ctc_loss():
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
CTCLoss(flatten='flatten')
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
CTCLoss(blank=None)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
CTCLoss(reduction=1)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
CTCLoss(zero_infinity='zero')
|
||||||
# test CTCLoss
|
# test CTCLoss
|
||||||
ctc_loss = CTCLoss()
|
ctc_loss = CTCLoss()
|
||||||
outputs = torch.zeros(2, 40, 37)
|
outputs = torch.zeros(2, 40, 37)
|
||||||
|
17
tests/test_models/test_ocr_neck.py
Normal file
17
tests/test_models/test_ocr_neck.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from mmocr.models.textrecog.necks import FPNOCR
|
||||||
|
|
||||||
|
|
||||||
|
def test_fpn_ocr():
    """FPNOCR's first output must be (1, 256, 32, 256) for four input maps."""
    # (channels, height, width) of each input feature map, in input order.
    level_shapes = [(128, 32, 256), (256, 16, 128), (512, 8, 64),
                    (512, 4, 32)]
    features = tuple(torch.rand(1, *shape) for shape in level_shapes)

    neck = FPNOCR(in_channels=[128, 256, 512, 512], out_channels=256)
    neck.init_weights()
    neck.train()

    outputs = neck(features)
    assert outputs[0].shape == torch.Size([1, 256, 32, 256])
|
147
tests/test_models/test_recog_config.py
Normal file
147
tests/test_models/test_recog_config.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
import copy
|
||||||
|
from os.path import dirname, exists, join
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300),
                    num_items=None):  # yapf: disable
    """Create a dummy image batch and matching per-image meta dicts.

    Args:
        num_kernels (int): Not used by this function.
        input_shape (tuple): Input batch dimensions, unpacked as
            ``(N, C, H, W)``.
        num_items (None | list[int]): Not used by this function.

    Returns:
        dict: ``'imgs'`` holds a float tensor of shape ``input_shape`` that
            requires grad; ``'img_metas'`` holds a list of ``N`` meta dicts.
    """
    batch_size, channels, height, width = input_shape
    hwc_shape = (height, width, channels)

    # A fixed seed keeps the generated pixel values reproducible.
    pixels = np.random.RandomState(0).rand(*input_shape)

    img_metas = []
    for _ in range(batch_size):
        img_metas.append({
            'img_shape': hwc_shape,
            'ori_shape': hwc_shape,
            'pad_shape': hwc_shape,
            'filename': '<demo>.png',
            'text': 'hello',
            'valid_ratio': 1.0,
        })

    imgs = torch.FloatTensor(pixels).requires_grad_(True)
    return {'imgs': imgs, 'img_metas': img_metas}
|
||||||
|
|
||||||
|
|
||||||
|
def _demo_gt_kernel_inputs(num_kernels=3, input_shape=(1, 3, 300, 300),
                           num_items=None):  # yapf: disable
    """Create one BitmapMasks of random kernels per batch item.

    Args:
        num_kernels (int): Number of random ``(H, W)`` arrays generated for
            each batch item.
        input_shape (tuple): Input batch dimensions, unpacked as
            ``(N, C, H, W)``; only ``N``, ``H`` and ``W`` are used.
        num_items (None | list[int]): Not used by this function.

    Returns:
        list: ``N`` ``BitmapMasks`` objects, each built from ``num_kernels``
            arrays together with ``H`` and ``W``.
    """
    from mmdet.core import BitmapMasks

    batch_size, _, height, width = input_shape

    # Values come from the global NumPy RNG, so they vary between runs.
    return [
        BitmapMasks(
            [np.random.rand(height, width) for _ in range(num_kernels)],
            height, width) for _ in range(batch_size)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_config_directory():
    """Find the predefined detector config directory.

    Returns:
        str: Path of the ``configs`` directory under the repository root.

    Raises:
        FileNotFoundError: If that directory does not exist.
    """
    try:
        # Assume we are running in the source mmocr repo
        repo_dpath = dirname(dirname(dirname(__file__)))
    except NameError:
        # For IPython development when this __file__ is not defined
        import mmocr
        repo_dpath = dirname(dirname(mmocr.__file__))
    config_dpath = join(repo_dpath, 'configs')
    if not exists(config_dpath):
        # FileNotFoundError is still an Exception, so broad handlers keep
        # working; the message now names the path that was looked up.
        raise FileNotFoundError(
            'Cannot find config path: {}'.format(config_dpath))
    return config_dpath
|
||||||
|
|
||||||
|
|
||||||
|
def _get_config_module(fname):
    """Load a configuration file from the config directory.

    Args:
        fname (str): Config file path, joined onto the directory returned by
            ``_get_config_directory``.

    Returns:
        The object returned by ``mmcv.Config.fromfile`` for that path.
    """
    from mmcv import Config

    config_fpath = join(_get_config_directory(), fname)
    return Config.fromfile(config_fpath)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_detector_cfg(fname):
    """Return a private copy of the ``model`` section of a config file.

    The section is deep copied so that a test can modify it without
    influencing other tests.

    Args:
        fname (str): Config file path, passed to ``_get_config_module``.

    Returns:
        A deep copy of the loaded config's ``model`` attribute.
    """
    return copy.deepcopy(_get_config_module(fname).model)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('cfg_file', [
    'textrecog/sar/sar_r31_parallel_decoder_academic.py',
    'textrecog/crnn/crnn_academic_dataset.py',
    'textrecog/nrtr/nrtr_r31_academic.py',
    'textrecog/robust_scanner/robustscanner_r31_academic.py',
    'textrecog/seg/seg_r31_1by16_fpnocr_academic.py'
])
def test_encoder_decoder_pipeline(cfg_file):
    """Build a recognizer from a config and run it in both forward modes.

    Args:
        cfg_file (str): Config path handed to ``_get_detector_cfg``.
    """
    model_cfg = _get_detector_cfg(cfg_file)
    model_cfg['pretrained'] = None

    from mmocr.models import build_detector
    detector = build_detector(model_cfg)

    # Configs with 'crnn' in their path get a 1-channel input, others 3.
    num_channels = 1 if 'crnn' in cfg_file else 3
    input_shape = (1, num_channels, 32, 160)
    mm_inputs = _demo_mm_inputs(0, input_shape)

    # Configs with 'seg' in their path also receive ground-truth kernels.
    train_kwargs = {}
    if 'seg' in cfg_file:
        train_kwargs['gt_kernels'] = _demo_gt_kernel_inputs(3, input_shape)

    imgs = mm_inputs['imgs']
    img_metas = mm_inputs['img_metas']

    # Test forward train
    losses = detector.forward(imgs, img_metas, **train_kwargs)
    assert isinstance(losses, dict)

    # Test forward test
    with torch.no_grad():
        # Each image is passed on its own with a leading batch axis.
        single_imgs = [one_img[None, :] for one_img in imgs]
        batch_results = [
            detector.forward([one_img], [[one_meta]], return_loss=False)
            for one_img, one_meta in zip(single_imgs, img_metas)
        ]

    # Test show_result
    results = {'text': 'hello', 'score': 1.0}
    img = np.random.rand(5, 5, 3)
    detector.show_result(img, results)
|
Loading…
x
Reference in New Issue
Block a user