mmocr/tests/models/textrecog/decoders/test_nrtr_decoder.py

91 lines
3.1 KiB
Python
Raw Normal View History

2022-06-10 10:14:36 +00:00
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
import torch
from mmengine.data import LabelData
2022-07-12 10:46:11 +00:00
from mmocr.data import TextRecogDataSample
2022-06-10 10:14:36 +00:00
from mmocr.models.textrecog.decoders import NRTRDecoder
2022-07-13 11:52:02 +00:00
from mmocr.testing import create_dummy_dict_file
2022-06-10 10:14:36 +00:00
class TestNRTRDecoder(TestCase):
    """Unit tests for :class:`NRTRDecoder` construction and forward passes."""

    # Size of the last output dimension asserted by the forward tests.
    # NOTE(review): presumably the class count of the dummy dictionary
    # written by ``create_dummy_dict_file`` plus the special tokens enabled
    # in ``_dict_cfg`` -- confirm against the Dictionary implementation.
    NUM_CLASSES = 39

    def setUp(self):
        """Build two ground-truth text samples ('Hello', 'World')."""
        gt_text_sample1 = TextRecogDataSample()
        gt_text = LabelData()
        gt_text.item = 'Hello'
        gt_text_sample1.gt_text = gt_text
        gt_text_sample1.set_metainfo(dict(valid_ratio=0.9))

        gt_text_sample2 = TextRecogDataSample()
        gt_text = LabelData()
        gt_text.item = 'World'
        gt_text_sample2.gt_text = gt_text
        gt_text_sample2.set_metainfo(dict(valid_ratio=1.0))

        self.data_info = [gt_text_sample1, gt_text_sample2]

    @staticmethod
    def _dict_cfg(dict_file):
        """Return the ``Dictionary`` config shared by every test.

        Args:
            dict_file (str): Path to a character dictionary file.

        Returns:
            dict: Dictionary config with start/end (shared), padding and
            unknown tokens enabled.
        """
        return dict(
            type='Dictionary',
            dict_file=dict_file,
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True)

    def test_init(self):
        # The context manager removes the temp dir even if construction
        # raises; the original only cleaned up on the success path.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            create_dummy_dict_file(dict_file)
            loss_cfg = dict(type='CEModuleLoss')
            NRTRDecoder(
                dictionary=self._dict_cfg(dict_file), module_loss=loss_cfg)

    def test_forward_train(self):
        encoder_out = torch.randn(2, 25, 512)
        max_seq_len = 40
        # Keep the dict file alive while the decoder is built and run;
        # the original never cleaned this temp dir up.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            create_dummy_dict_file(dict_file)
            loss_cfg = dict(type='CEModuleLoss')
            decoder = NRTRDecoder(
                dictionary=self._dict_cfg(dict_file),
                module_loss=loss_cfg,
                max_seq_len=max_seq_len)
            # Training needs targets derived from the ground-truth samples.
            data_samples = decoder.module_loss.get_targets(self.data_info)
            output = decoder.forward_train(
                out_enc=encoder_out, data_samples=data_samples)
        self.assertTupleEqual(
            tuple(output.shape), (2, max_seq_len, self.NUM_CLASSES))

    def test_forward_test(self):
        encoder_out = torch.randn(2, 25, 512)
        max_seq_len = 40
        # Keep the dict file alive while the decoder is built and run;
        # the original never cleaned this temp dir up.
        with tempfile.TemporaryDirectory() as tmp_dir:
            dict_file = osp.join(tmp_dir, 'fake_chars.txt')
            create_dummy_dict_file(dict_file)
            loss_cfg = dict(type='CEModuleLoss')
            decoder = NRTRDecoder(
                dictionary=self._dict_cfg(dict_file),
                module_loss=loss_cfg,
                max_seq_len=max_seq_len)
            output = decoder.forward_test(
                out_enc=encoder_out, data_samples=self.data_info)
        self.assertTupleEqual(
            tuple(output.shape), (2, max_seq_len, self.NUM_CLASSES))