mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Model] Add MobilenetV2 Backbone
This commit is contained in:
parent
ab6e897c6b
commit
4b185d3347
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .mobilenet_v2 import MobileNetV2
|
||||
from .nrtr_modality_transformer import NRTRModalityTransform
|
||||
from .resnet import ResNet
|
||||
from .resnet31_ocr import ResNet31OCR
|
||||
@ -8,5 +9,5 @@ from .very_deep_vgg import VeryDeepVgg
|
||||
|
||||
__all__ = [
|
||||
'ResNet31OCR', 'VeryDeepVgg', 'NRTRModalityTransform', 'ShallowCNN',
|
||||
'ResNetABI', 'ResNet'
|
||||
'ResNetABI', 'ResNet', 'MobileNetV2'
|
||||
]
|
||||
|
39
mmocr/models/textrecog/backbones/mobilenet_v2.py
Normal file
39
mmocr/models/textrecog/backbones/mobilenet_v2.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
import torch.nn as nn
|
||||
from mmdet.models.backbones import MobileNetV2 as MMDet_MobileNetV2
|
||||
from torch import Tensor
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MobileNetV2(MMDet_MobileNetV2):
|
||||
"""See mmdet.models.backbones.MobileNetV2 for details.
|
||||
|
||||
Args:
|
||||
pooling_layers (list): List of indices of pooling layers.
|
||||
"""
|
||||
# Parameters to build layers. 4 parameters are needed to construct a
|
||||
# layer, from left to right: expand_ratio, channel, num_blocks, stride.
|
||||
arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 1],
|
||||
[6, 64, 4, 1], [6, 96, 3, 1], [6, 160, 3, 1],
|
||||
[6, 320, 1, 1]]
|
||||
|
||||
def __init__(self, pooling_layers: List = [3, 4, 5]) -> None:
|
||||
super().__init__()
|
||||
self.pooling = nn.MaxPool2d((2, 2), (2, 1), (0, 1))
|
||||
self.pooling_layers = pooling_layers
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Forward function."""
|
||||
|
||||
x = self.conv1(x)
|
||||
for i, layer_name in enumerate(self.layers):
|
||||
layer = getattr(self, layer_name)
|
||||
x = layer(x)
|
||||
if i in self.pooling_layers:
|
||||
x = self.pooling(x)
|
||||
|
||||
return x
|
@ -0,0 +1,17 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
|
||||
from mmocr.models.textrecog.backbones import MobileNetV2
|
||||
|
||||
|
||||
class TestMobileNetV2(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.img = torch.rand(1, 3, 32, 160)
|
||||
|
||||
def test_mobilenetv2(self):
|
||||
mobilenet_v2 = MobileNetV2()
|
||||
self.assertEqual(
|
||||
mobilenet_v2(self.img).shape, torch.Size([1, 1280, 1, 43]))
|
Loading…
x
Reference in New Issue
Block a user