diff --git a/mmocr/models/textrecog/backbones/__init__.py b/mmocr/models/textrecog/backbones/__init__.py
index e9b68c38..09d8ad32 100644
--- a/mmocr/models/textrecog/backbones/__init__.py
+++ b/mmocr/models/textrecog/backbones/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .mobilenet_v2 import MobileNetV2
 from .nrtr_modality_transformer import NRTRModalityTransform
 from .resnet import ResNet
 from .resnet31_ocr import ResNet31OCR
@@ -8,5 +9,5 @@ from .very_deep_vgg import VeryDeepVgg
 
 __all__ = [
     'ResNet31OCR', 'VeryDeepVgg', 'NRTRModalityTransform', 'ShallowCNN',
-    'ResNetABI', 'ResNet'
+    'ResNetABI', 'ResNet', 'MobileNetV2'
 ]
diff --git a/mmocr/models/textrecog/backbones/mobilenet_v2.py b/mmocr/models/textrecog/backbones/mobilenet_v2.py
new file mode 100644
index 00000000..ec0a02b1
--- /dev/null
+++ b/mmocr/models/textrecog/backbones/mobilenet_v2.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional
+
+import torch.nn as nn
+from mmdet.models.backbones import MobileNetV2 as MMDet_MobileNetV2
+from torch import Tensor
+
+from mmocr.registry import MODELS
+
+
+@MODELS.register_module()
+class MobileNetV2(MMDet_MobileNetV2):
+    """MobileNetV2 backbone with max-pooling inserted between its layers.
+
+    See mmdet.models.backbones.MobileNetV2 for details of the base network.
+
+    Args:
+        pooling_layers (list[int], optional): Indices into ``self.layers``
+            after which the max-pooling layer is applied. Defaults to None,
+            which is treated as ``[3, 4, 5]``.
+    """
+    # Parameters to build layers. 4 parameters are needed to construct a
+    # layer, from left to right: expand_ratio, channel, num_blocks, stride.
+    arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 1],
+                     [6, 64, 4, 1], [6, 96, 3, 1], [6, 160, 3, 1],
+                     [6, 320, 1, 1]]
+
+    def __init__(self, pooling_layers: Optional[List[int]] = None) -> None:
+        super().__init__()
+        # Kernel (2, 2), stride (2, 1), padding (0, 1): the pooling strides
+        # by 2 vertically but only by 1 horizontally.
+        self.pooling = nn.MaxPool2d((2, 2), (2, 1), (0, 1))
+        # A ``None`` sentinel replaces the former list default so that
+        # instances never share (and cannot corrupt) one default list.
+        if pooling_layers is None:
+            pooling_layers = [3, 4, 5]
+        self.pooling_layers = pooling_layers
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Run the backbone, pooling after the configured layers.
+
+        Args:
+            x (Tensor): Input batch; the unit test uses shape
+                ``(1, 3, 32, 160)``.
+
+        Returns:
+            Tensor: Output of the last layer; ``(1, 1280, 1, 43)`` for the
+                unit-test input above.
+        """
+        x = self.conv1(x)
+        for i, layer_name in enumerate(self.layers):
+            layer = getattr(self, layer_name)
+            x = layer(x)
+            if i in self.pooling_layers:
+                x = self.pooling(x)
+
+        return x
diff --git a/tests/test_models/test_textrecog/test_backbones/test_mobilenet_v2.py b/tests/test_models/test_textrecog/test_backbones/test_mobilenet_v2.py
new file mode 100644
index 00000000..7f55ef44
--- /dev/null
+++ b/tests/test_models/test_textrecog/test_backbones/test_mobilenet_v2.py
@@ -0,0 +1,17 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+
+from mmocr.models.textrecog.backbones import MobileNetV2
+
+
+class TestMobileNetV2(TestCase):
+
+    def setUp(self) -> None:
+        self.img = torch.rand(1, 3, 32, 160)
+
+    def test_mobilenetv2(self):
+        mobilenet_v2 = MobileNetV2()
+        self.assertEqual(
+            mobilenet_v2(self.img).shape, torch.Size([1, 1280, 1, 43]))