mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
parent
c8793ac141
commit
fbb5c8cda1
@ -1,6 +1,3 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmdet.models.builder import DETECTORS
|
||||
from .encode_decode_recognizer import EncodeDecodeRecognizer
|
||||
|
||||
@ -8,11 +5,3 @@ from .encode_decode_recognizer import EncodeDecodeRecognizer
|
||||
@DETECTORS.register_module()
|
||||
class CRNNNet(EncodeDecodeRecognizer):
|
||||
"""CTC-loss based recognizer."""
|
||||
|
||||
def forward_conversion(self, params, img):
|
||||
x = self.extract_feat(img)
|
||||
x = self.encoder(x)
|
||||
outs = self.decoder(x)
|
||||
outs = F.softmax(outs, dim=2)
|
||||
params = torch.pow(params, 1)
|
||||
return outs, params
|
||||
|
114
tests/test_dataset/test_kie_dataset.py
Normal file
114
tests/test_dataset/test_kie_dataset.py
Normal file
@ -0,0 +1,114 @@
|
||||
import json
|
||||
import math
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmocr.datasets.kie_dataset import KIEDataset
|
||||
|
||||
|
||||
def _create_dummy_ann_file(ann_file):
|
||||
ann_info1 = {
|
||||
'file_name':
|
||||
'sample1.png',
|
||||
'height':
|
||||
200,
|
||||
'width':
|
||||
200,
|
||||
'annotations': [{
|
||||
'text': 'store',
|
||||
'box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0],
|
||||
'label': 1
|
||||
}, {
|
||||
'text': 'address',
|
||||
'box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0],
|
||||
'label': 1
|
||||
}, {
|
||||
'text': 'price',
|
||||
'box': [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0],
|
||||
'label': 1
|
||||
}, {
|
||||
'text': '1.0',
|
||||
'box': [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0],
|
||||
'label': 1
|
||||
}, {
|
||||
'text': 'google',
|
||||
'box': [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0],
|
||||
'label': 1
|
||||
}]
|
||||
}
|
||||
with open(ann_file, 'w') as fw:
|
||||
for ann_info in [ann_info1]:
|
||||
fw.write(json.dumps(ann_info) + '\n')
|
||||
|
||||
return ann_info1
|
||||
|
||||
|
||||
def _create_dummy_dict_file(dict_file):
|
||||
dict_str = '0123'
|
||||
with open(dict_file, 'w') as fw:
|
||||
for char in list(dict_str):
|
||||
fw.write(char + '\n')
|
||||
|
||||
return dict_str
|
||||
|
||||
|
||||
def _create_dummy_loader():
|
||||
loader = dict(
|
||||
type='HardDiskLoader',
|
||||
repeat=1,
|
||||
parser=dict(
|
||||
type='LineJsonParser',
|
||||
keys=['file_name', 'height', 'width', 'annotations']))
|
||||
return loader
|
||||
|
||||
|
||||
def test_kie_dataset():
|
||||
tmp_dir = tempfile.TemporaryDirectory()
|
||||
# create dummy data
|
||||
ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
|
||||
ann_info1 = _create_dummy_ann_file(ann_file)
|
||||
|
||||
dict_file = osp.join(tmp_dir.name, 'fake_dict.txt')
|
||||
_create_dummy_dict_file(dict_file)
|
||||
|
||||
# test initialization
|
||||
loader = _create_dummy_loader()
|
||||
dataset = KIEDataset(ann_file, loader, dict_file, pipeline=[])
|
||||
|
||||
tmp_dir.cleanup()
|
||||
|
||||
# test pre_pipeline
|
||||
img_info = dataset.data_infos[0]
|
||||
results = dict(img_info=img_info)
|
||||
dataset.pre_pipeline(results)
|
||||
assert results['img_prefix'] == dataset.img_prefix
|
||||
|
||||
# test _parse_anno_info
|
||||
annos = ann_info1['annotations']
|
||||
with pytest.raises(AssertionError):
|
||||
dataset._parse_anno_info(annos[0])
|
||||
tmp_annos = [{
|
||||
'text': 'store',
|
||||
'box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]
|
||||
}]
|
||||
with pytest.raises(AssertionError):
|
||||
dataset._parse_anno_info(tmp_annos)
|
||||
|
||||
return_anno = dataset._parse_anno_info(annos)
|
||||
assert 'bboxes' in return_anno
|
||||
assert 'relations' in return_anno
|
||||
assert 'texts' in return_anno
|
||||
assert 'labels' in return_anno
|
||||
|
||||
# test evaluation
|
||||
result = {}
|
||||
result['nodes'] = torch.full((5, 5), 1, dtype=torch.float)
|
||||
result['nodes'][:, 1] = 100.
|
||||
print('hello', result['nodes'].size())
|
||||
results = [result for _ in range(5)]
|
||||
|
||||
eval_res = dataset.evaluate(results)
|
||||
assert math.isclose(eval_res['macro_f1'], 0.2, abs_tol=1e-4)
|
@ -4,6 +4,7 @@ import unittest.mock as mock
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
from PIL import Image
|
||||
|
||||
import mmocr.datasets.pipelines.ocr_transforms as transforms
|
||||
|
||||
@ -92,3 +93,48 @@ def test_online_crop(mock_random):
|
||||
|
||||
results = rci(results)
|
||||
assert np.allclose(results['img'].shape, [100, 100, 3])
|
||||
|
||||
|
||||
def test_fancy_pca():
|
||||
input_tensor = torch.rand(3, 32, 100)
|
||||
|
||||
rci = transforms.FancyPCA()
|
||||
|
||||
results = {'img': input_tensor}
|
||||
results = rci(results)
|
||||
|
||||
assert results['img'].shape == torch.Size([3, 32, 100])
|
||||
|
||||
|
||||
@mock.patch('%s.transforms.np.random.uniform' % __name__)
|
||||
def test_random_padding(mock_random):
|
||||
kwargs = dict(max_ratio=[0.0, 0.0, 0.0, 0.0], box_type=None)
|
||||
|
||||
mock_random.side_effect = [1, 1, 1, 1]
|
||||
|
||||
src_img = np.ones((32, 100, 3), dtype=np.uint8)
|
||||
results = {'img': src_img, 'img_shape': (32, 100, 3)}
|
||||
|
||||
rci = transforms.RandomPaddingOCR(**kwargs)
|
||||
|
||||
results = rci(results)
|
||||
print(results['img'].shape)
|
||||
assert np.allclose(results['img_shape'], [96, 300, 3])
|
||||
|
||||
|
||||
def test_opencv2pil():
|
||||
src_img = np.ones((32, 100, 3), dtype=np.uint8)
|
||||
results = {'img': src_img}
|
||||
rci = transforms.OpencvToPil()
|
||||
|
||||
results = rci(results)
|
||||
assert np.allclose(results['img'].size, (100, 32))
|
||||
|
||||
|
||||
def test_pil2opencv():
|
||||
src_img = Image.new('RGB', (100, 32), color=(255, 255, 255))
|
||||
results = {'img': src_img}
|
||||
rci = transforms.PilToOpencv()
|
||||
|
||||
results = rci(results)
|
||||
assert np.allclose(results['img'].shape, (32, 100, 3))
|
||||
|
@ -4,13 +4,13 @@ import torch
|
||||
from mmocr.models.textrecog import SegHead
|
||||
|
||||
|
||||
def test_cafcn_head():
|
||||
def test_seg_head():
|
||||
with pytest.raises(AssertionError):
|
||||
SegHead(num_classes='100')
|
||||
with pytest.raises(AssertionError):
|
||||
SegHead(num_classes=-1)
|
||||
|
||||
cafcn_head = SegHead(num_classes=37)
|
||||
seg_head = SegHead(num_classes=37)
|
||||
out_neck = (torch.rand(1, 128, 32, 32), )
|
||||
out_head = cafcn_head(out_neck)
|
||||
out_head = seg_head(out_neck)
|
||||
assert out_head.shape == torch.Size([1, 37, 32, 32])
|
||||
|
@ -6,6 +6,14 @@ from mmocr.models.textrecog.losses import CELoss, CTCLoss, SARLoss, TFLoss
|
||||
|
||||
|
||||
def test_ctc_loss():
|
||||
with pytest.raises(AssertionError):
|
||||
CTCLoss(flatten='flatten')
|
||||
with pytest.raises(AssertionError):
|
||||
CTCLoss(blank=None)
|
||||
with pytest.raises(AssertionError):
|
||||
CTCLoss(reduction=1)
|
||||
with pytest.raises(AssertionError):
|
||||
CTCLoss(zero_infinity='zero')
|
||||
# test CTCLoss
|
||||
ctc_loss = CTCLoss()
|
||||
outputs = torch.zeros(2, 40, 37)
|
||||
|
17
tests/test_models/test_ocr_neck.py
Normal file
17
tests/test_models/test_ocr_neck.py
Normal file
@ -0,0 +1,17 @@
|
||||
import torch
|
||||
|
||||
from mmocr.models.textrecog.necks import FPNOCR
|
||||
|
||||
|
||||
def test_fpn_ocr():
|
||||
in_s1 = torch.rand(1, 128, 32, 256)
|
||||
in_s2 = torch.rand(1, 256, 16, 128)
|
||||
in_s3 = torch.rand(1, 512, 8, 64)
|
||||
in_s4 = torch.rand(1, 512, 4, 32)
|
||||
|
||||
fpn_ocr = FPNOCR(in_channels=[128, 256, 512, 512], out_channels=256)
|
||||
fpn_ocr.init_weights()
|
||||
fpn_ocr.train()
|
||||
|
||||
out_neck = fpn_ocr((in_s1, in_s2, in_s3, in_s4))
|
||||
assert out_neck[0].shape == torch.Size([1, 256, 32, 256])
|
147
tests/test_models/test_recog_config.py
Normal file
147
tests/test_models/test_recog_config.py
Normal file
@ -0,0 +1,147 @@
|
||||
import copy
|
||||
from os.path import dirname, exists, join
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def _demo_mm_inputs(num_kernels=0, input_shape=(1, 3, 300, 300),
|
||||
num_items=None): # yapf: disable
|
||||
"""Create a superset of inputs needed to run test or train batches.
|
||||
|
||||
Args:
|
||||
input_shape (tuple): Input batch dimensions.
|
||||
|
||||
num_items (None | list[int]): Specifies the number of boxes
|
||||
for each batch item.
|
||||
"""
|
||||
|
||||
(N, C, H, W) = input_shape
|
||||
|
||||
rng = np.random.RandomState(0)
|
||||
|
||||
imgs = rng.rand(*input_shape)
|
||||
|
||||
img_metas = [{
|
||||
'img_shape': (H, W, C),
|
||||
'ori_shape': (H, W, C),
|
||||
'pad_shape': (H, W, C),
|
||||
'filename': '<demo>.png',
|
||||
'text': 'hello',
|
||||
'valid_ratio': 1.0,
|
||||
} for _ in range(N)]
|
||||
|
||||
mm_inputs = {
|
||||
'imgs': torch.FloatTensor(imgs).requires_grad_(True),
|
||||
'img_metas': img_metas
|
||||
}
|
||||
return mm_inputs
|
||||
|
||||
|
||||
def _demo_gt_kernel_inputs(num_kernels=3, input_shape=(1, 3, 300, 300),
|
||||
num_items=None): # yapf: disable
|
||||
"""Create a superset of inputs needed to run test or train batches.
|
||||
|
||||
Args:
|
||||
input_shape (tuple): Input batch dimensions.
|
||||
|
||||
num_items (None | list[int]): Specifies the number of boxes
|
||||
for each batch item.
|
||||
"""
|
||||
from mmdet.core import BitmapMasks
|
||||
|
||||
(N, C, H, W) = input_shape
|
||||
gt_kernels = []
|
||||
|
||||
for batch_idx in range(N):
|
||||
kernels = []
|
||||
for kernel_inx in range(num_kernels):
|
||||
kernel = np.random.rand(H, W)
|
||||
kernels.append(kernel)
|
||||
gt_kernels.append(BitmapMasks(kernels, H, W))
|
||||
|
||||
return gt_kernels
|
||||
|
||||
|
||||
def _get_config_directory():
|
||||
"""Find the predefined detector config directory."""
|
||||
try:
|
||||
# Assume we are running in the source mmocr repo
|
||||
repo_dpath = dirname(dirname(dirname(__file__)))
|
||||
except NameError:
|
||||
# For IPython development when this __file__ is not defined
|
||||
import mmocr
|
||||
repo_dpath = dirname(dirname(mmocr.__file__))
|
||||
config_dpath = join(repo_dpath, 'configs')
|
||||
if not exists(config_dpath):
|
||||
raise Exception('Cannot find config path')
|
||||
return config_dpath
|
||||
|
||||
|
||||
def _get_config_module(fname):
|
||||
"""Load a configuration as a python module."""
|
||||
from mmcv import Config
|
||||
config_dpath = _get_config_directory()
|
||||
config_fpath = join(config_dpath, fname)
|
||||
config_mod = Config.fromfile(config_fpath)
|
||||
return config_mod
|
||||
|
||||
|
||||
def _get_detector_cfg(fname):
|
||||
"""Grab configs necessary to create a detector.
|
||||
|
||||
These are deep copied to allow for safe modification of parameters without
|
||||
influencing other tests.
|
||||
"""
|
||||
config = _get_config_module(fname)
|
||||
model = copy.deepcopy(config.model)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cfg_file', [
|
||||
'textrecog/sar/sar_r31_parallel_decoder_academic.py',
|
||||
'textrecog/crnn/crnn_academic_dataset.py',
|
||||
'textrecog/nrtr/nrtr_r31_academic.py',
|
||||
'textrecog/robust_scanner/robustscanner_r31_academic.py',
|
||||
'textrecog/seg/seg_r31_1by16_fpnocr_academic.py'
|
||||
])
|
||||
def test_encoder_decoder_pipeline(cfg_file):
|
||||
model = _get_detector_cfg(cfg_file)
|
||||
model['pretrained'] = None
|
||||
|
||||
from mmocr.models import build_detector
|
||||
detector = build_detector(model)
|
||||
|
||||
input_shape = (1, 3, 32, 160)
|
||||
if 'crnn' in cfg_file:
|
||||
input_shape = (1, 1, 32, 160)
|
||||
mm_inputs = _demo_mm_inputs(0, input_shape)
|
||||
gt_kernels = None
|
||||
if 'seg' in cfg_file:
|
||||
gt_kernels = _demo_gt_kernel_inputs(3, input_shape)
|
||||
|
||||
imgs = mm_inputs.pop('imgs')
|
||||
img_metas = mm_inputs.pop('img_metas')
|
||||
|
||||
# Test forward train
|
||||
if 'seg' in cfg_file:
|
||||
losses = detector.forward(imgs, img_metas, gt_kernels=gt_kernels)
|
||||
else:
|
||||
losses = detector.forward(imgs, img_metas)
|
||||
assert isinstance(losses, dict)
|
||||
|
||||
# Test forward test
|
||||
with torch.no_grad():
|
||||
img_list = [g[None, :] for g in imgs]
|
||||
batch_results = []
|
||||
for one_img, one_meta in zip(img_list, img_metas):
|
||||
result = detector.forward([one_img], [[one_meta]],
|
||||
return_loss=False)
|
||||
batch_results.append(result)
|
||||
|
||||
# Test show_result
|
||||
|
||||
results = {'text': 'hello', 'score': 1.0}
|
||||
img = np.random.rand(5, 5, 3)
|
||||
detector.show_result(img, results)
|
Loading…
x
Reference in New Issue
Block a user