mirror of https://github.com/open-mmlab/mmocr.git
[Encoder] sar encoder
parent fe43b4e767
commit 47771788f0
@@ -1,11 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+from typing import Dict, Optional, Sequence, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-import mmocr.utils as utils
+from mmocr.core.data_structures import TextRecogDataSample
 from mmocr.registry import MODELS
 from .base_encoder import BaseEncoder

@@ -18,37 +19,44 @@ class SAREncoder(BaseEncoder):

     Args:
         enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
-        enc_do_rnn (float): Dropout probability of RNN layer in encoder.
-        enc_gru (bool): If True, use GRU, else LSTM in encoder.
-        d_model (int): Dim :math:`D_i` of channels from backbone.
-        d_enc (int): Dim :math:`D_m` of encoder RNN layer.
-        mask (bool): If True, mask padding in RNN sequence.
+            Defaults to False.
+        rnn_dropout (float): Dropout probability of RNN layer in encoder.
+            Defaults to 0.0.
+        enc_gru (bool): If True, use GRU, else LSTM in encoder. Defaults
+            to False.
+        d_model (int): Dim :math:`D_i` of channels from backbone. Defaults
+            to 512.
+        d_enc (int): Dim :math:`D_m` of encoder RNN layer. Defaults to 512.
+        mask (bool): If True, mask padding in RNN sequence. Defaults to
+            True.
+        init_cfg (dict or list[dict], optional): Initialization configs.
+            Defaults to [dict(type='Xavier', layer='Conv2d'),
+            dict(type='Uniform', layer='BatchNorm2d')].
     """

     def __init__(self,
-                 enc_bi_rnn=False,
-                 enc_do_rnn=0.0,
-                 enc_gru=False,
-                 d_model=512,
-                 d_enc=512,
-                 mask=True,
-                 init_cfg=[
+                 enc_bi_rnn: bool = False,
+                 rnn_dropout: Union[int, float] = 0.0,
+                 enc_gru: bool = False,
+                 d_model: int = 512,
+                 d_enc: int = 512,
+                 mask: bool = True,
+                 init_cfg: Sequence[Dict] = [
                      dict(type='Xavier', layer='Conv2d'),
                      dict(type='Uniform', layer='BatchNorm2d')
                  ],
-                 **kwargs):
+                 **kwargs) -> None:
         super().__init__(init_cfg=init_cfg)
         assert isinstance(enc_bi_rnn, bool)
-        assert isinstance(enc_do_rnn, (int, float))
-        assert 0 <= enc_do_rnn < 1.0
+        assert isinstance(rnn_dropout, (int, float))
+        assert 0 <= rnn_dropout < 1.0
         assert isinstance(enc_gru, bool)
         assert isinstance(d_model, int)
         assert isinstance(d_enc, int)
         assert isinstance(mask, bool)

         self.enc_bi_rnn = enc_bi_rnn
-        self.enc_do_rnn = enc_do_rnn
+        self.rnn_dropout = rnn_dropout
         self.mask = mask

         # LSTM Encoder
@@ -57,7 +65,7 @@ class SAREncoder(BaseEncoder):
             hidden_size=d_enc,
             num_layers=2,
             batch_first=True,
-            dropout=enc_do_rnn,
+            dropout=rnn_dropout,
             bidirectional=enc_bi_rnn)
         if enc_gru:
             self.rnn_encoder = nn.GRU(**kwargs)
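For context, a minimal constructor sketch (illustrative only, not part of this commit; the 0.1 dropout is an arbitrary value, not a project default). The one user-visible rename is enc_do_rnn -> rnn_dropout:

    # Sketch assuming mmocr at this commit; defaults are unchanged by the rename.
    from mmocr.models.textrecog.encoders import SAREncoder

    encoder = SAREncoder(
        enc_bi_rnn=False,  # unidirectional two-layer RNN
        rnn_dropout=0.1,   # formerly enc_do_rnn; arbitrary illustrative value
        enc_gru=False,     # True swaps the LSTM for a GRU
        d_model=512,       # D_i, backbone channel dim
        d_enc=512,         # D_m, encoder RNN dim
        mask=True)         # honor valid_ratio padding masks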
@@ -68,24 +76,29 @@ class SAREncoder(BaseEncoder):
         encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
         self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

-    def forward(self, feat, img_metas=None):
+    def forward(
+        self,
+        feat: torch.Tensor,
+        data_samples: Optional[Sequence[TextRecogDataSample]] = None
+    ) -> torch.Tensor:
         """
         Args:
             feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
-            img_metas (dict): A dict that contains meta information of input
-                images. Preferably with the key ``valid_ratio``.
+            data_samples (list[TextRecogDataSample], optional): Batch of
+                TextRecogDataSample, containing valid_ratio information.
+                Defaults to None.

         Returns:
             Tensor: A tensor of shape :math:`(N, D_m)`.
         """
-        if img_metas is not None:
-            assert utils.is_type_list(img_metas, dict)
-            assert len(img_metas) == feat.size(0)
+        if data_samples is not None:
+            assert len(data_samples) == feat.size(0)

         valid_ratios = None
-        if img_metas is not None:
+        if data_samples is not None:
             valid_ratios = [
-                img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
+                data_sample.get('valid_ratio', 1.0)
+                for data_sample in data_samples
             ] if self.mask else None

         h_feat = feat.size(2)

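A short sketch of the new forward contract (again an assumption-laden sketch, not commit code): valid_ratio now travels in TextRecogDataSample metainfo instead of the old img_metas list of dicts.

    import torch
    from mmocr.core.data_structures import TextRecogDataSample
    from mmocr.models.textrecog.encoders import SAREncoder

    encoder = SAREncoder()
    feat = torch.randn(1, 512, 4, 40)           # (N, D_i, H, W) backbone feature
    sample = TextRecogDataSample()
    sample.set_metainfo(dict(valid_ratio=0.8))  # was img_metas[0]['valid_ratio']
    out_enc = encoder(feat, [sample])           # holistic feature of shape (1, 512)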
@@ -1,37 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import pytest
 import torch

 from mmocr.models.textrecog.encoders import (ABIVisionModel, BaseEncoder,
-                                             NRTREncoder, SAREncoder,
-                                             TransformerEncoder)
-
-
-def test_sar_encoder():
-    with pytest.raises(AssertionError):
-        SAREncoder(enc_bi_rnn='bi')
-    with pytest.raises(AssertionError):
-        SAREncoder(enc_do_rnn=2)
-    with pytest.raises(AssertionError):
-        SAREncoder(enc_gru='gru')
-    with pytest.raises(AssertionError):
-        SAREncoder(d_model=512.5)
-    with pytest.raises(AssertionError):
-        SAREncoder(d_enc=200.5)
-    with pytest.raises(AssertionError):
-        SAREncoder(mask='mask')
-
-    encoder = SAREncoder()
-    encoder.init_weights()
-    encoder.train()
-
-    feat = torch.randn(1, 512, 4, 40)
-    img_metas = [{'valid_ratio': 1.0}]
-    with pytest.raises(AssertionError):
-        encoder(feat, img_metas * 2)
-    out_enc = encoder(feat, img_metas)
-
-    assert out_enc.shape == torch.Size([1, 512])
+                                             NRTREncoder, TransformerEncoder)


 def test_nrtr_encoder():

@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.models.textrecog.encoders import SAREncoder
+
+
+class TestSAREncoder(TestCase):
+
+    def setUp(self):
+        gt_text_sample1 = TextRecogDataSample()
+        gt_text_sample1.set_metainfo(dict(valid_ratio=0.9))
+
+        gt_text_sample2 = TextRecogDataSample()
+        gt_text_sample2.set_metainfo(dict(valid_ratio=1.0))
+
+        self.data_info = [gt_text_sample1, gt_text_sample2]
+
+    def test_init(self):
+        with self.assertRaises(AssertionError):
+            SAREncoder(enc_bi_rnn='bi')
+        with self.assertRaises(AssertionError):
+            SAREncoder(rnn_dropout=2)
+        with self.assertRaises(AssertionError):
+            SAREncoder(enc_gru='gru')
+        with self.assertRaises(AssertionError):
+            SAREncoder(d_model=512.5)
+        with self.assertRaises(AssertionError):
+            SAREncoder(d_enc=200.5)
+        with self.assertRaises(AssertionError):
+            SAREncoder(mask='mask')
+
+    def test_forward(self):
+        encoder = SAREncoder()
+        encoder.init_weights()
+        encoder.train()
+
+        feat = torch.randn(2, 512, 4, 40)
+        with self.assertRaises(AssertionError):
+            encoder(feat, self.data_info * 2)
+        out_enc = encoder(feat, self.data_info)
+        self.assertEqual(out_enc.shape, torch.Size([2, 512]))
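Since the new suite is unittest.TestCase based, it should run under either python -m unittest or plain pytest collection; the pytest.raises-style assertions from the removed test are fully replaced by assertRaises.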