diff --git a/mmocr/models/textrecog/encoders/channel_reduction_encoder.py b/mmocr/models/textrecog/encoders/channel_reduction_encoder.py index 790c57f2..0f6797d1 100644 --- a/mmocr/models/textrecog/encoders/channel_reduction_encoder.py +++ b/mmocr/models/textrecog/encoders/channel_reduction_encoder.py @@ -1,6 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Sequence + +import torch import torch.nn as nn +from mmocr.core.data_structures import TextRecogDataSample from mmocr.registry import MODELS from .base_encoder import BaseEncoder @@ -13,23 +17,32 @@ class ChannelReductionEncoder(BaseEncoder): in_channels (int): Number of input channels. out_channels (int): Number of output channels. init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to dict(type='Xavier', layer='Conv2d'). """ - def __init__(self, - in_channels, - out_channels, - init_cfg=dict(type='Xavier', layer='Conv2d')): + def __init__( + self, + in_channels: int, + out_channels: int, + init_cfg: Dict = dict(type='Xavier', layer='Conv2d') + ) -> None: super().__init__(init_cfg=init_cfg) self.layer = nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0) - def forward(self, feat, img_metas=None): + def forward( + self, + feat: torch.Tensor, + data_samples: Optional[Sequence[TextRecogDataSample]] = None + ) -> torch.Tensor: """ Args: feat (Tensor): Image features with the shape of :math:`(N, C_{in}, H, W)`. - img_metas (None): Unused. + data_samples (list[TextRecogDataSample], optional): Batch of + TextRecogDataSample, containing valid_ratio information. + Defaults to None. Returns: Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`. diff --git a/tests/test_models/test_textrecog/test_encoders/test_channel_reduction_encoder.py b/tests/test_models/test_textrecog/test_encoders/test_channel_reduction_encoder.py new file mode 100644 index 00000000..dd1fa5e8 --- /dev/null +++ b/tests/test_models/test_textrecog/test_encoders/test_channel_reduction_encoder.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest + +import torch + +from mmocr.core.data_structures import TextRecogDataSample +from mmocr.models.textrecog.encoders import ChannelReductionEncoder + + +class TestChannelReductionEncoder(unittest.TestCase): + + def setUp(self): + self.feat = torch.randn(2, 512, 8, 25) + gt_text_sample1 = TextRecogDataSample() + gt_text_sample1.set_metainfo(dict(valid_ratio=0.9)) + + gt_text_sample2 = TextRecogDataSample() + gt_text_sample2.set_metainfo(dict(valid_ratio=1.0)) + + self.data_info = [gt_text_sample1, gt_text_sample2] + + def test_encoder(self): + encoder = ChannelReductionEncoder(512, 256) + encoder.train() + out_enc = encoder(self.feat, self.data_info) + self.assertEqual(out_enc.shape, torch.Size([2, 256, 8, 25]))