From 47771788f0c79814cf5a8846354acbde3dc37f23 Mon Sep 17 00:00:00 2001
From: liukuikun <641417025@qq.com>
Date: Fri, 24 Jun 2022 03:22:56 +0000
Subject: [PATCH] [Encoder] sar encoder

---
 .../models/textrecog/encoders/sar_encoder.py | 65 +++++++++++--------
 old_tests/test_models/test_ocr_encoder.py    | 31 +--------
 .../test_encoders/test_sar_encoder.py        | 44 +++++++++++++
 3 files changed, 84 insertions(+), 56 deletions(-)
 create mode 100644 tests/test_models/test_textrecog/test_encoders/test_sar_encoder.py

diff --git a/mmocr/models/textrecog/encoders/sar_encoder.py b/mmocr/models/textrecog/encoders/sar_encoder.py
index 1e48c8ad..23c578fd 100644
--- a/mmocr/models/textrecog/encoders/sar_encoder.py
+++ b/mmocr/models/textrecog/encoders/sar_encoder.py
@@ -1,11 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+from typing import Dict, Optional, Sequence, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-import mmocr.utils as utils
+from mmocr.core.data_structures import TextRecogDataSample
 from mmocr.registry import MODELS
 from .base_encoder import BaseEncoder
 
@@ -18,37 +19,44 @@ class SAREncoder(BaseEncoder):
     Args:
         enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
-        enc_do_rnn (float): Dropout probability of RNN layer in encoder.
-        enc_gru (bool): If True, use GRU, else LSTM in encoder.
-        d_model (int): Dim :math:`D_i` of channels from backbone.
-        d_enc (int): Dim :math:`D_m` of encoder RNN layer.
-        mask (bool): If True, mask padding in RNN sequence.
+            Defaults to False.
+        rnn_dropout (float): Dropout probability of RNN layer in encoder.
+            Defaults to 0.0.
+        enc_gru (bool): If True, use GRU, else LSTM in encoder. Defaults
+            to False.
+        d_model (int): Dim :math:`D_i` of channels from backbone. Defaults
+            to 512.
+        d_enc (int): Dim :math:`D_m` of encoder RNN layer. Defaults to 512.
+        mask (bool): If True, mask padding in RNN sequence. Defaults to
+            True.
         init_cfg (dict or list[dict], optional): Initialization configs.
+            Defaults to [dict(type='Xavier', layer='Conv2d'),
+            dict(type='Uniform', layer='BatchNorm2d')].
""" def __init__(self, - enc_bi_rnn=False, - enc_do_rnn=0.0, - enc_gru=False, - d_model=512, - d_enc=512, - mask=True, - init_cfg=[ + enc_bi_rnn: bool = False, + rnn_dropout: Union[int, float] = 0.0, + enc_gru: bool = False, + d_model: int = 512, + d_enc: int = 512, + mask: bool = True, + init_cfg: Sequence[Dict] = [ dict(type='Xavier', layer='Conv2d'), dict(type='Uniform', layer='BatchNorm2d') ], - **kwargs): + **kwargs) -> None: super().__init__(init_cfg=init_cfg) assert isinstance(enc_bi_rnn, bool) - assert isinstance(enc_do_rnn, (int, float)) - assert 0 <= enc_do_rnn < 1.0 + assert isinstance(rnn_dropout, (int, float)) + assert 0 <= rnn_dropout < 1.0 assert isinstance(enc_gru, bool) assert isinstance(d_model, int) assert isinstance(d_enc, int) assert isinstance(mask, bool) self.enc_bi_rnn = enc_bi_rnn - self.enc_do_rnn = enc_do_rnn + self.rnn_dropout = rnn_dropout self.mask = mask # LSTM Encoder @@ -57,7 +65,7 @@ class SAREncoder(BaseEncoder): hidden_size=d_enc, num_layers=2, batch_first=True, - dropout=enc_do_rnn, + dropout=rnn_dropout, bidirectional=enc_bi_rnn) if enc_gru: self.rnn_encoder = nn.GRU(**kwargs) @@ -68,24 +76,29 @@ class SAREncoder(BaseEncoder): encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size) - def forward(self, feat, img_metas=None): + def forward( + self, + feat: torch.Tensor, + data_samples: Optional[Sequence[TextRecogDataSample]] = None + ) -> torch.Tensor: """ Args: feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`. - img_metas (dict): A dict that contains meta information of input - images. Preferably with the key ``valid_ratio``. + data_samples (list[TextRecogDataSample], optional): Batch of + TextRecogDataSample, containing valid_ratio information. + Defaults to None. Returns: Tensor: A tensor of shape :math:`(N, D_m)`. """ - if img_metas is not None: - assert utils.is_type_list(img_metas, dict) - assert len(img_metas) == feat.size(0) + if data_samples is not None: + assert len(data_samples) == feat.size(0) valid_ratios = None - if img_metas is not None: + if data_samples is not None: valid_ratios = [ - img_meta.get('valid_ratio', 1.0) for img_meta in img_metas + data_sample.get('valid_ratio', 1.0) + for data_sample in data_samples ] if self.mask else None h_feat = feat.size(2) diff --git a/old_tests/test_models/test_ocr_encoder.py b/old_tests/test_models/test_ocr_encoder.py index 2cfd8ea9..5218b3ad 100644 --- a/old_tests/test_models/test_ocr_encoder.py +++ b/old_tests/test_models/test_ocr_encoder.py @@ -1,37 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import pytest
 import torch
 
 from mmocr.models.textrecog.encoders import (ABIVisionModel, BaseEncoder,
-                                             NRTREncoder, SAREncoder,
-                                             TransformerEncoder)
-
-
-def test_sar_encoder():
-    with pytest.raises(AssertionError):
-        SAREncoder(enc_bi_rnn='bi')
-    with pytest.raises(AssertionError):
-        SAREncoder(enc_do_rnn=2)
-    with pytest.raises(AssertionError):
-        SAREncoder(enc_gru='gru')
-    with pytest.raises(AssertionError):
-        SAREncoder(d_model=512.5)
-    with pytest.raises(AssertionError):
-        SAREncoder(d_enc=200.5)
-    with pytest.raises(AssertionError):
-        SAREncoder(mask='mask')
-
-    encoder = SAREncoder()
-    encoder.init_weights()
-    encoder.train()
-
-    feat = torch.randn(1, 512, 4, 40)
-    img_metas = [{'valid_ratio': 1.0}]
-    with pytest.raises(AssertionError):
-        encoder(feat, img_metas * 2)
-    out_enc = encoder(feat, img_metas)
-
-    assert out_enc.shape == torch.Size([1, 512])
+                                             NRTREncoder, TransformerEncoder)
 
 
 def test_nrtr_encoder():
diff --git a/tests/test_models/test_textrecog/test_encoders/test_sar_encoder.py b/tests/test_models/test_textrecog/test_encoders/test_sar_encoder.py
new file mode 100644
index 00000000..51fe6015
--- /dev/null
+++ b/tests/test_models/test_textrecog/test_encoders/test_sar_encoder.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+
+from mmocr.core.data_structures import TextRecogDataSample
+from mmocr.models.textrecog.encoders import SAREncoder
+
+
+class TestSAREncoder(TestCase):
+
+    def setUp(self):
+        gt_text_sample1 = TextRecogDataSample()
+        gt_text_sample1.set_metainfo(dict(valid_ratio=0.9))
+
+        gt_text_sample2 = TextRecogDataSample()
+        gt_text_sample2.set_metainfo(dict(valid_ratio=1.0))
+
+        self.data_info = [gt_text_sample1, gt_text_sample2]
+
+    def test_init(self):
+        with self.assertRaises(AssertionError):
+            SAREncoder(enc_bi_rnn='bi')
+        with self.assertRaises(AssertionError):
+            SAREncoder(rnn_dropout=2)
+        with self.assertRaises(AssertionError):
+            SAREncoder(enc_gru='gru')
+        with self.assertRaises(AssertionError):
+            SAREncoder(d_model=512.5)
+        with self.assertRaises(AssertionError):
+            SAREncoder(d_enc=200.5)
+        with self.assertRaises(AssertionError):
+            SAREncoder(mask='mask')
+
+    def test_forward(self):
+        encoder = SAREncoder()
+        encoder.init_weights()
+        encoder.train()
+
+        feat = torch.randn(2, 512, 4, 40)
+        with self.assertRaises(AssertionError):
+            encoder(feat, self.data_info * 2)
+        out_enc = encoder(feat, self.data_info)
+        self.assertEqual(out_enc.shape, torch.Size([2, 512]))
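
Usage sketch (illustration only, not part of the diff): a minimal example of
driving the refactored encoder, mirroring the new unit test above. The only
assumption beyond the patch is the dev-1.x import path for
`TextRecogDataSample` (`mmocr.core.data_structures`), which may differ in
other branches.

    import torch

    from mmocr.core.data_structures import TextRecogDataSample
    from mmocr.models.textrecog.encoders import SAREncoder

    # Defaults per the docstring: 2-layer unidirectional LSTM,
    # d_model = d_enc = 512, rnn_dropout = 0.0, masking enabled.
    encoder = SAREncoder()
    encoder.init_weights()
    encoder.eval()

    # Backbone feature map of shape (N, D_i, H, W).
    feat = torch.randn(2, 512, 4, 40)

    # Each data sample carries valid_ratio in its metainfo so the encoder
    # can ignore right-side padding introduced by batch resizing.
    data_samples = []
    for ratio in (0.9, 1.0):
        sample = TextRecogDataSample()
        sample.set_metainfo(dict(valid_ratio=ratio))
        data_samples.append(sample)

    out_enc = encoder(feat, data_samples)
    assert out_enc.shape == torch.Size([2, 512])  # (N, D_m)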
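The last hunk in forward() ends at the context line `h_feat = feat.size(2)`.
For readers tracing how valid_ratios is consumed, the remainder of the method
(untouched by this patch; the sketch below is paraphrased from the existing
implementation and is not part of the diff) pools the height away, runs the
RNN along the width axis, and takes the hidden state at the last valid time
step of each sample:

    # Inside SAREncoder.forward, after the hunk shown above.
    feat_v = F.max_pool2d(
        feat, kernel_size=(h_feat, 1), stride=1, ceil_mode=True)
    feat_v = feat_v.squeeze(2)                     # (N, D_i, W)
    feat_v = feat_v.permute(0, 2, 1).contiguous()  # (N, W, D_i)
    holistic_feat = self.rnn_encoder(feat_v)[0]    # (N, T, C)

    if valid_ratios is not None:
        valid_hf = []
        T = holistic_feat.size(1)
        for i, valid_ratio in enumerate(valid_ratios):
            # Last time step that still covers unpadded content.
            valid_step = min(T, math.ceil(T * valid_ratio)) - 1
            valid_hf.append(holistic_feat[i, valid_step, :])
        valid_hf = torch.stack(valid_hf, dim=0)    # (N, C)
    else:
        valid_hf = holistic_feat[:, -1, :]         # (N, C)
    return self.linear(valid_hf)                   # (N, D_m)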