[Encoder] SAR encoder

pull/1178/head
liukuikun 2022-06-24 03:22:56 +00:00 committed by gaotongxiao
parent fe43b4e767
commit 47771788f0
3 changed files with 84 additions and 56 deletions


@@ -1,11 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+from typing import Dict, Optional, Sequence, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-import mmocr.utils as utils
+from mmocr.core.data_structures import TextRecogDataSample
 from mmocr.registry import MODELS
 from .base_encoder import BaseEncoder
@@ -18,37 +19,44 @@ class SAREncoder(BaseEncoder):
     Args:
         enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
-        enc_do_rnn (float): Dropout probability of RNN layer in encoder.
-        enc_gru (bool): If True, use GRU, else LSTM in encoder.
-        d_model (int): Dim :math:`D_i` of channels from backbone.
-        d_enc (int): Dim :math:`D_m` of encoder RNN layer.
-        mask (bool): If True, mask padding in RNN sequence.
+            Defaults to False.
+        rnn_dropout (float): Dropout probability of RNN layer in encoder.
+            Defaults to 0.0.
+        enc_gru (bool): If True, use GRU, else LSTM in encoder. Defaults
+            to False.
+        d_model (int): Dim :math:`D_i` of channels from backbone. Defaults
+            to 512.
+        d_enc (int): Dim :math:`D_m` of encoder RNN layer. Defaults to 512.
+        mask (bool): If True, mask padding in RNN sequence. Defaults to
+            True.
         init_cfg (dict or list[dict], optional): Initialization configs.
             Defaults to [dict(type='Xavier', layer='Conv2d'),
            dict(type='Uniform', layer='BatchNorm2d')].
     """

     def __init__(self,
-                 enc_bi_rnn=False,
-                 enc_do_rnn=0.0,
-                 enc_gru=False,
-                 d_model=512,
-                 d_enc=512,
-                 mask=True,
-                 init_cfg=[
+                 enc_bi_rnn: bool = False,
+                 rnn_dropout: Union[int, float] = 0.0,
+                 enc_gru: bool = False,
+                 d_model: int = 512,
+                 d_enc: int = 512,
+                 mask: bool = True,
+                 init_cfg: Sequence[Dict] = [
                      dict(type='Xavier', layer='Conv2d'),
                      dict(type='Uniform', layer='BatchNorm2d')
                  ],
-                 **kwargs):
+                 **kwargs) -> None:
         super().__init__(init_cfg=init_cfg)

         assert isinstance(enc_bi_rnn, bool)
-        assert isinstance(enc_do_rnn, (int, float))
-        assert 0 <= enc_do_rnn < 1.0
+        assert isinstance(rnn_dropout, (int, float))
+        assert 0 <= rnn_dropout < 1.0
         assert isinstance(enc_gru, bool)
         assert isinstance(d_model, int)
         assert isinstance(d_enc, int)
         assert isinstance(mask, bool)

         self.enc_bi_rnn = enc_bi_rnn
-        self.enc_do_rnn = enc_do_rnn
+        self.rnn_dropout = rnn_dropout
         self.mask = mask

         # LSTM Encoder
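The hunk above renames `enc_do_rnn` to `rnn_dropout` and adds type annotations to the constructor. A minimal sketch of instantiating the encoder against the new signature (illustrative only, not part of the diff):

```python
from mmocr.models.textrecog.encoders import SAREncoder

# `rnn_dropout` replaces the old `enc_do_rnn` keyword after this commit.
encoder = SAREncoder(enc_bi_rnn=True, rnn_dropout=0.1, d_model=512, d_enc=512)

# With enc_bi_rnn=True the RNN output width doubles:
# encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) = 512 * 2 = 1024
```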
@@ -57,7 +65,7 @@ class SAREncoder(BaseEncoder):
             hidden_size=d_enc,
             num_layers=2,
             batch_first=True,
-            dropout=enc_do_rnn,
+            dropout=rnn_dropout,
             bidirectional=enc_bi_rnn)
         if enc_gru:
             self.rnn_encoder = nn.GRU(**kwargs)
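`enc_gru` selects between the two recurrent modules, which share the keyword arguments shown in this hunk. A self-contained sketch of that selection logic, using the hypothetical helper name `build_rnn_encoder`:

```python
import torch.nn as nn


def build_rnn_encoder(d_model: int = 512,
                      d_enc: int = 512,
                      rnn_dropout: float = 0.0,
                      enc_bi_rnn: bool = False,
                      enc_gru: bool = False) -> nn.Module:
    # Mirrors SAREncoder.__init__: one kwargs dict feeds either module.
    kwargs = dict(
        input_size=d_model,
        hidden_size=d_enc,
        num_layers=2,
        batch_first=True,
        dropout=rnn_dropout,
        bidirectional=enc_bi_rnn)
    return nn.GRU(**kwargs) if enc_gru else nn.LSTM(**kwargs)
```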
@@ -68,24 +76,29 @@ class SAREncoder(BaseEncoder):
         encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
         self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

-    def forward(self, feat, img_metas=None):
+    def forward(
+        self,
+        feat: torch.Tensor,
+        data_samples: Optional[Sequence[TextRecogDataSample]] = None
+    ) -> torch.Tensor:
         """
         Args:
             feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
-            img_metas (dict): A dict that contains meta information of input
-                images. Preferably with the key ``valid_ratio``.
+            data_samples (list[TextRecogDataSample], optional): Batch of
+                TextRecogDataSample, containing valid_ratio information.
+                Defaults to None.

         Returns:
             Tensor: A tensor of shape :math:`(N, D_m)`.
         """
-        if img_metas is not None:
-            assert utils.is_type_list(img_metas, dict)
-            assert len(img_metas) == feat.size(0)
+        if data_samples is not None:
+            assert len(data_samples) == feat.size(0)

         valid_ratios = None
-        if img_metas is not None:
+        if data_samples is not None:
             valid_ratios = [
-                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+                data_sample.get('valid_ratio', 1.0)
+                for data_sample in data_samples
             ] if self.mask else None

         h_feat = feat.size(2)
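After this change, per-image metadata such as `valid_ratio` travels in `TextRecogDataSample` objects instead of `img_metas` dicts. A sketch of the migrated forward call, modeled on the new unit test below:

```python
import torch

from mmocr.core.data_structures import TextRecogDataSample
from mmocr.models.textrecog.encoders import SAREncoder

encoder = SAREncoder()
encoder.init_weights()
encoder.train()

# One data sample per image; valid_ratio defaults to 1.0 when absent.
sample = TextRecogDataSample()
sample.set_metainfo(dict(valid_ratio=0.9))

feat = torch.randn(1, 512, 4, 40)  # (N, D_i, H, W)
out_enc = encoder(feat, [sample])  # shape (N, D_m) == (1, 512)
```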


@@ -1,37 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import pytest
 import torch

 from mmocr.models.textrecog.encoders import (ABIVisionModel, BaseEncoder,
-                                             NRTREncoder, SAREncoder,
-                                             TransformerEncoder)
+                                             NRTREncoder, TransformerEncoder)


-def test_sar_encoder():
-    with pytest.raises(AssertionError):
-        SAREncoder(enc_bi_rnn='bi')
-    with pytest.raises(AssertionError):
-        SAREncoder(enc_do_rnn=2)
-    with pytest.raises(AssertionError):
-        SAREncoder(enc_gru='gru')
-    with pytest.raises(AssertionError):
-        SAREncoder(d_model=512.5)
-    with pytest.raises(AssertionError):
-        SAREncoder(d_enc=200.5)
-    with pytest.raises(AssertionError):
-        SAREncoder(mask='mask')
-
-    encoder = SAREncoder()
-    encoder.init_weights()
-    encoder.train()
-
-    feat = torch.randn(1, 512, 4, 40)
-    img_metas = [{'valid_ratio': 1.0}]
-    with pytest.raises(AssertionError):
-        encoder(feat, img_metas * 2)
-    out_enc = encoder(feat, img_metas)
-
-    assert out_enc.shape == torch.Size([1, 512])
-
-
 def test_nrtr_encoder():


@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.models.textrecog.encoders import SAREncoder
+
+
+class TestSAREncoder(TestCase):
+
+    def setUp(self):
+        gt_text_sample1 = TextRecogDataSample()
+        gt_text_sample1.set_metainfo(dict(valid_ratio=0.9))
+
+        gt_text_sample2 = TextRecogDataSample()
+        gt_text_sample2.set_metainfo(dict(valid_ratio=1.0))
+
+        self.data_info = [gt_text_sample1, gt_text_sample2]
+
+    def test_init(self):
+        with self.assertRaises(AssertionError):
+            SAREncoder(enc_bi_rnn='bi')
+        with self.assertRaises(AssertionError):
+            SAREncoder(rnn_dropout=2)
+        with self.assertRaises(AssertionError):
+            SAREncoder(enc_gru='gru')
+        with self.assertRaises(AssertionError):
+            SAREncoder(d_model=512.5)
+        with self.assertRaises(AssertionError):
+            SAREncoder(d_enc=200.5)
+        with self.assertRaises(AssertionError):
+            SAREncoder(mask='mask')
+
+    def test_forward(self):
+        encoder = SAREncoder()
+        encoder.init_weights()
+        encoder.train()
+
+        feat = torch.randn(2, 512, 4, 40)
+        with self.assertRaises(AssertionError):
+            encoder(feat, self.data_info * 2)
+        out_enc = encoder(feat, self.data_info)
+        self.assertEqual(out_enc.shape, torch.Size([2, 512]))
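The constructor assertions exercised by `test_init` behave the same from user code; for instance, a dropout value outside `[0, 1)` is rejected (illustrative snippet, not part of the diff):

```python
from mmocr.models.textrecog.encoders import SAREncoder

try:
    SAREncoder(rnn_dropout=2)  # violates `assert 0 <= rnn_dropout < 1.0`
except AssertionError:
    print('rnn_dropout outside [0, 1) is rejected at construction time')
```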