# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
from mmengine.data import LabelData

from mmocr.data import TextRecogDataSample
from mmocr.models.textrecog.decoders import SequenceAttentionDecoder


class TestSequenceAttentionDecoder(TestCase):

    def setUp(self):
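        # Two data samples with ground-truth texts and valid ratios,
        # standing in for what a real dataloader would supply.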
        gt_text_sample1 = TextRecogDataSample()
        gt_text = LabelData()
        gt_text.item = 'Hello'
        gt_text_sample1.gt_text = gt_text
        gt_text_sample1.set_metainfo(dict(valid_ratio=0.9))

        gt_text_sample2 = TextRecogDataSample()
        gt_text = LabelData()
        gt_text.item = 'World'
        gt_text_sample2.gt_text = gt_text
        gt_text_sample2.set_metainfo(dict(valid_ratio=1.0))

        self.data_info = [gt_text_sample1, gt_text_sample2]
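        # 36-class lowercase+digit dictionary; the shared start/end token,
        # padding token and unknown token bring the total to 39 classes.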
        self.dict_cfg = dict(
            type='Dictionary',
            dict_file='dicts/lower_english_digits.txt',
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True)

    def test_init(self):
        module_loss_cfg = dict(type='CEModuleLoss')
        decoder = SequenceAttentionDecoder(
            dictionary=self.dict_cfg,
            module_loss=module_loss_cfg,
            return_feature=False)
        # With return_feature=False, the decoder ends in a linear
        # classification head over the dictionary.
        self.assertIsInstance(decoder.prediction, torch.nn.Linear)

    def test_forward_train(self):
        feat = torch.randn(2, 512, 8, 8)
        encoder_out = torch.randn(2, 128, 8, 8)
        module_loss_cfg = dict(type='CEModuleLoss')
        decoder = SequenceAttentionDecoder(
            dictionary=self.dict_cfg,
            module_loss=module_loss_cfg,
            return_feature=False)
        data_samples = decoder.module_loss.get_targets(self.data_info)
        output = decoder.forward_train(
            feat=feat, out_enc=encoder_out, data_samples=data_samples)
        # (batch_size, max_seq_len, num_classes)
        self.assertTupleEqual(tuple(output.shape), (2, 40, 39))
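
        # With return_feature=True (the default), the decoder returns raw
        # 512-dimensional attention features instead of class logits.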
        decoder = SequenceAttentionDecoder(
            dictionary=self.dict_cfg, module_loss=module_loss_cfg)
        output = decoder.forward_train(
            feat=feat, out_enc=encoder_out, data_samples=data_samples)
        self.assertTupleEqual(tuple(output.shape), (2, 40, 512))
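
        # Features or encoder outputs with mismatched channel counts
        # should be rejected with an AssertionError.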
        feat_new = torch.randn(2, 256, 8, 8)
        with self.assertRaises(AssertionError):
            decoder.forward_train(feat_new, encoder_out, self.data_info)
        encoder_out_new = torch.randn(2, 256, 8, 8)
        with self.assertRaises(AssertionError):
            decoder.forward_train(feat, encoder_out_new, self.data_info)

    def test_forward_test(self):
        feat = torch.randn(2, 512, 8, 8)
        encoder_out = torch.randn(2, 128, 8, 8)
        module_loss_cfg = dict(type='CEModuleLoss')
        decoder = SequenceAttentionDecoder(
            dictionary=self.dict_cfg,
            module_loss=module_loss_cfg,
            return_feature=False)
        # Inference path: raw data samples go in directly; output keeps the
        # (batch_size, max_seq_len, num_classes) shape.
        output = decoder.forward_test(feat, encoder_out, self.data_info)
        self.assertTupleEqual(tuple(output.shape), (2, 40, 39))
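

if __name__ == '__main__':
    # Convenience hook so the file can also be run directly with `python`;
    # a pytest-style runner discovers the TestCase either way.
    import unittest
    unittest.main()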