mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Model] Add MobilenetV2 Backbone
This commit is contained in:
parent
ab6e897c6b
commit
4b185d3347
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from .mobilenet_v2 import MobileNetV2
|
||||||
from .nrtr_modality_transformer import NRTRModalityTransform
|
from .nrtr_modality_transformer import NRTRModalityTransform
|
||||||
from .resnet import ResNet
|
from .resnet import ResNet
|
||||||
from .resnet31_ocr import ResNet31OCR
|
from .resnet31_ocr import ResNet31OCR
|
||||||
@ -8,5 +9,5 @@ from .very_deep_vgg import VeryDeepVgg
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ResNet31OCR', 'VeryDeepVgg', 'NRTRModalityTransform', 'ShallowCNN',
|
'ResNet31OCR', 'VeryDeepVgg', 'NRTRModalityTransform', 'ShallowCNN',
|
||||||
'ResNetABI', 'ResNet'
|
'ResNetABI', 'ResNet', 'MobileNetV2'
|
||||||
]
|
]
|
||||||
|
39
mmocr/models/textrecog/backbones/mobilenet_v2.py
Normal file
39
mmocr/models/textrecog/backbones/mobilenet_v2.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from mmdet.models.backbones import MobileNetV2 as MMDet_MobileNetV2
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from mmocr.registry import MODELS
|
||||||
|
|
||||||
|
|
||||||
|
@MODELS.register_module()
|
||||||
|
class MobileNetV2(MMDet_MobileNetV2):
|
||||||
|
"""See mmdet.models.backbones.MobileNetV2 for details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pooling_layers (list): List of indices of pooling layers.
|
||||||
|
"""
|
||||||
|
# Parameters to build layers. 4 parameters are needed to construct a
|
||||||
|
# layer, from left to right: expand_ratio, channel, num_blocks, stride.
|
||||||
|
arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 1],
|
||||||
|
[6, 64, 4, 1], [6, 96, 3, 1], [6, 160, 3, 1],
|
||||||
|
[6, 320, 1, 1]]
|
||||||
|
|
||||||
|
def __init__(self, pooling_layers: List = [3, 4, 5]) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.pooling = nn.MaxPool2d((2, 2), (2, 1), (0, 1))
|
||||||
|
self.pooling_layers = pooling_layers
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""Forward function."""
|
||||||
|
|
||||||
|
x = self.conv1(x)
|
||||||
|
for i, layer_name in enumerate(self.layers):
|
||||||
|
layer = getattr(self, layer_name)
|
||||||
|
x = layer(x)
|
||||||
|
if i in self.pooling_layers:
|
||||||
|
x = self.pooling(x)
|
||||||
|
|
||||||
|
return x
|
@ -0,0 +1,17 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmocr.models.textrecog.backbones import MobileNetV2
|
||||||
|
|
||||||
|
|
||||||
|
class TestMobileNetV2(TestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.img = torch.rand(1, 3, 32, 160)
|
||||||
|
|
||||||
|
def test_mobilenetv2(self):
|
||||||
|
mobilenet_v2 = MobileNetV2()
|
||||||
|
self.assertEqual(
|
||||||
|
mobilenet_v2(self.img).shape, torch.Size([1, 1280, 1, 43]))
|
Loading…
x
Reference in New Issue
Block a user