mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[NRTR] NRTR backbone
This commit is contained in:
parent
781166764c
commit
d41921f03d
@ -1,4 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
@ -7,17 +10,25 @@ from mmocr.registry import MODELS
|
||||
|
||||
@MODELS.register_module()
|
||||
class NRTRModalityTransform(BaseModule):
|
||||
"""Modality transform in NRTR.
|
||||
|
||||
def __init__(self,
|
||||
input_channels=3,
|
||||
init_cfg=[
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(type='Uniform', layer='BatchNorm2d')
|
||||
]):
|
||||
Args:
|
||||
in_channels (int): Input channel of image. Defaults to 3.
|
||||
init_cfg (dict or list[dict], optional): Initialization configs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
init_cfg: Optional[Union[Dict, Sequence[Dict]]] = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(type='Uniform', layer='BatchNorm2d')
|
||||
]
|
||||
) -> None:
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.conv_1 = nn.Conv2d(
|
||||
in_channels=input_channels,
|
||||
in_channels=in_channels,
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
@ -36,7 +47,15 @@ class NRTRModalityTransform(BaseModule):
|
||||
|
||||
self.linear = nn.Linear(512, 512)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Backbone forward.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Image tensor of shape :math:`(N, C, W, H)`. W, H
|
||||
is the width and height of image.
|
||||
Return:
|
||||
Tensor: Output tensor.
|
||||
"""
|
||||
x = self.conv_1(x)
|
||||
x = self.relu_1(x)
|
||||
x = self.bn_1(x)
|
||||
|
@ -0,0 +1,19 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from mmocr.models.textrecog.backbones import NRTRModalityTransform
|
||||
|
||||
|
||||
class TestNRTRBackbone(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.img = torch.randn(2, 3, 32, 100)
|
||||
|
||||
def test_encoder(self):
|
||||
nrtr_backbone = NRTRModalityTransform()
|
||||
nrtr_backbone.init_weights()
|
||||
nrtr_backbone.train()
|
||||
out_enc = nrtr_backbone(self.img)
|
||||
self.assertEqual(out_enc.shape, torch.Size([2, 512, 1, 25]))
|
Loading…
x
Reference in New Issue
Block a user