diff --git a/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py b/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py index 5ef99132..9690f14f 100644 --- a/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py +++ b/mmocr/models/textrecog/backbones/nrtr_modality_transformer.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Sequence, Union + +import torch import torch.nn as nn from mmcv.runner import BaseModule @@ -7,17 +10,25 @@ from mmocr.registry import MODELS @MODELS.register_module() class NRTRModalityTransform(BaseModule): + """Modality transform in NRTR. - def __init__(self, - input_channels=3, - init_cfg=[ - dict(type='Kaiming', layer='Conv2d'), - dict(type='Uniform', layer='BatchNorm2d') - ]): + Args: + in_channels (int): Input channel of image. Defaults to 3. + init_cfg (dict or list[dict], optional): Initialization configs. + """ + + def __init__( + self, + in_channels: int = 3, + init_cfg: Optional[Union[Dict, Sequence[Dict]]] = [ + dict(type='Kaiming', layer='Conv2d'), + dict(type='Uniform', layer='BatchNorm2d') + ] + ) -> None: super().__init__(init_cfg=init_cfg) self.conv_1 = nn.Conv2d( - in_channels=input_channels, + in_channels=in_channels, out_channels=32, kernel_size=3, stride=2, @@ -36,7 +47,15 @@ class NRTRModalityTransform(BaseModule): self.linear = nn.Linear(512, 512) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Backbone forward. + + Args: + x (torch.Tensor): Image tensor of shape :math:`(N, C, H, W)`. H, W + are the height and width of the image. + Returns: + Tensor: Output tensor. 
+ """ x = self.conv_1(x) x = self.relu_1(x) x = self.bn_1(x) diff --git a/tests/test_models/test_textrecog/test_backbone/test_nrtr_backbone.py b/tests/test_models/test_textrecog/test_backbone/test_nrtr_backbone.py new file mode 100644 index 00000000..3243cd40 --- /dev/null +++ b/tests/test_models/test_textrecog/test_backbone/test_nrtr_backbone.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest + +import torch + +from mmocr.models.textrecog.backbones import NRTRModalityTransform + + +class TestNRTRBackbone(unittest.TestCase): + + def setUp(self): + self.img = torch.randn(2, 3, 32, 100) + + def test_encoder(self): + nrtr_backbone = NRTRModalityTransform() + nrtr_backbone.init_weights() + nrtr_backbone.train() + out_enc = nrtr_backbone(self.img) + self.assertEqual(out_enc.shape, torch.Size([2, 512, 1, 25]))