[Model] Add MobilenetV2 Backbone

This commit is contained in:
wangxinyu 2022-07-05 03:31:56 +00:00 committed by gaotongxiao
parent ab6e897c6b
commit 4b185d3347
3 changed files with 58 additions and 1 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mobilenet_v2 import MobileNetV2
from .nrtr_modality_transformer import NRTRModalityTransform
from .resnet import ResNet
from .resnet31_ocr import ResNet31OCR
@ -8,5 +9,5 @@ from .very_deep_vgg import VeryDeepVgg
__all__ = [
'ResNet31OCR', 'VeryDeepVgg', 'NRTRModalityTransform', 'ShallowCNN',
'ResNetABI', 'ResNet'
'ResNetABI', 'ResNet', 'MobileNetV2'
]

View File

@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch.nn as nn
from mmdet.models.backbones import MobileNetV2 as MMDet_MobileNetV2
from torch import Tensor
from mmocr.registry import MODELS
@MODELS.register_module()
class MobileNetV2(MMDet_MobileNetV2):
"""See mmdet.models.backbones.MobileNetV2 for details.
Args:
pooling_layers (list): List of indices of pooling layers.
"""
# Parameters to build layers. 4 parameters are needed to construct a
# layer, from left to right: expand_ratio, channel, num_blocks, stride.
arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 1],
[6, 64, 4, 1], [6, 96, 3, 1], [6, 160, 3, 1],
[6, 320, 1, 1]]
def __init__(self, pooling_layers: List = [3, 4, 5]) -> None:
super().__init__()
self.pooling = nn.MaxPool2d((2, 2), (2, 1), (0, 1))
self.pooling_layers = pooling_layers
def forward(self, x: Tensor) -> Tensor:
"""Forward function."""
x = self.conv1(x)
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.pooling_layers:
x = self.pooling(x)
return x

View File

@ -0,0 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmocr.models.textrecog.backbones import MobileNetV2
class TestMobileNetV2(TestCase):
def setUp(self) -> None:
self.img = torch.rand(1, 3, 32, 160)
def test_mobilenetv2(self):
mobilenet_v2 = MobileNetV2()
self.assertEqual(
mobilenet_v2(self.img).shape, torch.Size([1, 1280, 1, 43]))