2021-08-17 17:39:30 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
2021-04-02 23:54:57 +08:00
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
2021-12-15 11:21:54 +08:00
|
|
|
from mmocr.models.textrecog.encoders import (ABIVisionModel, BaseEncoder,
|
|
|
|
NRTREncoder, SAREncoder,
|
|
|
|
SatrnEncoder, TransformerEncoder)
|
2021-04-02 23:54:57 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_sar_encoder():
    """Exercise SAREncoder argument validation and a forward pass."""
    # Each of these keyword sets is expected to be rejected by an assertion
    # at construction time.
    rejected_kwargs = (
        {'enc_bi_rnn': 'bi'},
        {'enc_do_rnn': 2},
        {'enc_gru': 'gru'},
        {'d_model': 512.5},
        {'d_enc': 200.5},
        {'mask': 'mask'},
    )
    for bad_kwargs in rejected_kwargs:
        with pytest.raises(AssertionError):
            SAREncoder(**bad_kwargs)

    sar_encoder = SAREncoder()
    sar_encoder.init_weights()
    sar_encoder.train()

    feature_map = torch.randn(1, 512, 4, 40)
    metas = [{'valid_ratio': 1.0}]

    # Two meta entries for a batch of one feature map must be rejected.
    with pytest.raises(AssertionError):
        sar_encoder(feature_map, metas * 2)

    encoded = sar_encoder(feature_map, metas)
    expected_shape = torch.Size([1, 512])
    assert encoded.shape == expected_shape
|
|
|
|
|
|
|
|
|
2021-12-15 11:21:54 +08:00
|
|
|
def test_nrtr_encoder():
    """Check the output shape of a default NRTREncoder forward pass.

    A ``(1, 512, 1, 25)`` input is expected to produce a ``(1, 25, 512)``
    output.
    """
    tf_encoder = NRTREncoder()
    tf_encoder.init_weights()
    tf_encoder.train()

    feat = torch.randn(1, 512, 1, 25)
    out_enc = tf_encoder(feat)
    assert out_enc.shape == torch.Size([1, 25, 512])
|
2021-04-02 23:54:57 +08:00
|
|
|
|
|
|
|
|
2021-08-19 22:02:58 +08:00
|
|
|
def test_satrn_encoder():
    """Check the output shape of a default SatrnEncoder forward pass.

    A ``(1, 512, 8, 25)`` input is expected to produce a ``(1, 200, 512)``
    output.
    """
    encoder = SatrnEncoder()
    encoder.init_weights()
    encoder.train()

    feature_map = torch.randn(1, 512, 8, 25)
    encoded = encoder(feature_map)

    expected_shape = torch.Size([1, 200, 512])
    assert encoded.shape == expected_shape
|
2021-08-19 22:02:58 +08:00
|
|
|
|
|
|
|
|
2021-04-02 23:54:57 +08:00
|
|
|
def test_base_encoder():
    """Check that a BaseEncoder forward pass yields a ``(1, 256, 4, 40)``
    output for an input of that same shape."""
    base_encoder = BaseEncoder()
    base_encoder.init_weights()
    base_encoder.train()

    feature_map = torch.randn(1, 256, 4, 40)
    encoded = base_encoder(feature_map)

    expected_shape = torch.Size([1, 256, 4, 40])
    assert encoded.shape == expected_shape
|
2021-12-15 11:21:54 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_transformer_encoder():
    """Check that a default TransformerEncoder returns a
    ``(10, 512, 8, 32)`` output for a ``(10, 512, 8, 32)`` input."""
    encoder = TransformerEncoder()
    features = torch.randn(10, 512, 8, 32)

    encoded = encoder(features)
    expected_shape = torch.Size([10, 512, 8, 32])
    assert encoded.shape == expected_shape
|
|
|
|
|
|
|
|
|
|
|
|
def test_abi_vision_model():
    """Check the shapes of the entries returned by ABIVisionModel.

    The model is built with an ``ABIVisionDecoder`` config using
    ``max_seq_len=10`` and fed a ``(1, 512, 8, 32)`` input.
    """
    decoder_cfg = {
        'type': 'ABIVisionDecoder',
        'max_seq_len': 10,
        'use_result': None,
    }
    vision_model = ABIVisionModel(decoder=decoder_cfg)

    features = torch.randn(1, 512, 8, 32)
    outputs = vision_model(features)

    # Result key -> expected shape, checked in the listed order.
    expected_shapes = (
        ('feature', torch.Size([1, 10, 512])),
        ('logits', torch.Size([1, 10, 90])),
        ('attn_scores', torch.Size([1, 10, 8, 32])),
    )
    for key, shape in expected_shapes:
        assert outputs[key].shape == shape
|