mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
change argument names according to convention (#131)
* change argument names according to convention * fix bug when renaming leakyRelu
This commit is contained in:
parent
47896a3f80
commit
e7f27ae317
@ -4,7 +4,7 @@ label_convertor = dict(
|
|||||||
model = dict(
|
model = dict(
|
||||||
type='CRNNNet',
|
type='CRNNNet',
|
||||||
preprocessor=None,
|
preprocessor=None,
|
||||||
backbone=dict(type='VeryDeepVgg', leakyRelu=False),
|
backbone=dict(type='VeryDeepVgg', leaky_relu=False),
|
||||||
encoder=None,
|
encoder=None,
|
||||||
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
|
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
|
||||||
loss=dict(type='CTCLoss', flatten=False),
|
loss=dict(type='CTCLoss', flatten=False),
|
||||||
|
@ -21,7 +21,7 @@ label_convertor = dict(
|
|||||||
model = dict(
|
model = dict(
|
||||||
type='CRNNNet',
|
type='CRNNNet',
|
||||||
preprocessor=None,
|
preprocessor=None,
|
||||||
backbone=dict(type='VeryDeepVgg', leakyRelu=False, input_channels=1),
|
backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
|
||||||
encoder=None,
|
encoder=None,
|
||||||
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
|
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
|
||||||
loss=dict(type='CTCLoss'),
|
loss=dict(type='CTCLoss'),
|
||||||
|
@ -9,11 +9,11 @@ class VeryDeepVgg(nn.Module):
|
|||||||
"""Implement VGG-VeryDeep backbone for text recognition, modified from
|
"""Implement VGG-VeryDeep backbone for text recognition, modified from
|
||||||
`VGG-VeryDeep <https://arxiv.org/pdf/1409.1556.pdf>`_
|
`VGG-VeryDeep <https://arxiv.org/pdf/1409.1556.pdf>`_
|
||||||
Args:
|
Args:
|
||||||
|
leaky_relu (bool): Use leakyRelu or not.
|
||||||
input_channels (int): Number of channels of input image tensor.
|
input_channels (int): Number of channels of input image tensor.
|
||||||
leakyRelu (bool): Use leakyRelu or not.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, leakyRelu=True, input_channels=3):
|
def __init__(self, leaky_relu=True, input_channels=3):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
ks = [3, 3, 3, 3, 3, 3, 2]
|
ks = [3, 3, 3, 3, 3, 3, 2]
|
||||||
@ -25,32 +25,32 @@ class VeryDeepVgg(nn.Module):
|
|||||||
|
|
||||||
cnn = nn.Sequential()
|
cnn = nn.Sequential()
|
||||||
|
|
||||||
def convRelu(i, batchNormalization=False):
|
def conv_relu(i, batch_normalization=False):
|
||||||
nIn = input_channels if i == 0 else nm[i - 1]
|
n_in = input_channels if i == 0 else nm[i - 1]
|
||||||
nOut = nm[i]
|
n_out = nm[i]
|
||||||
cnn.add_module('conv{0}'.format(i),
|
cnn.add_module('conv{0}'.format(i),
|
||||||
nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
|
nn.Conv2d(n_in, n_out, ks[i], ss[i], ps[i]))
|
||||||
if batchNormalization:
|
if batch_normalization:
|
||||||
cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
|
cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(n_out))
|
||||||
if leakyRelu:
|
if leaky_relu:
|
||||||
cnn.add_module('relu{0}'.format(i),
|
cnn.add_module('relu{0}'.format(i),
|
||||||
nn.LeakyReLU(0.2, inplace=True))
|
nn.LeakyReLU(0.2, inplace=True))
|
||||||
else:
|
else:
|
||||||
cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
|
cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
|
||||||
|
|
||||||
convRelu(0)
|
conv_relu(0)
|
||||||
cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64
|
cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64
|
||||||
convRelu(1)
|
conv_relu(1)
|
||||||
cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32
|
cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32
|
||||||
convRelu(2, True)
|
conv_relu(2, True)
|
||||||
convRelu(3)
|
conv_relu(3)
|
||||||
cnn.add_module('pooling{0}'.format(2),
|
cnn.add_module('pooling{0}'.format(2),
|
||||||
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
|
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
|
||||||
convRelu(4, True)
|
conv_relu(4, True)
|
||||||
convRelu(5)
|
conv_relu(5)
|
||||||
cnn.add_module('pooling{0}'.format(3),
|
cnn.add_module('pooling{0}'.format(3),
|
||||||
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
|
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
|
||||||
convRelu(6, True) # 512x1x16
|
conv_relu(6, True) # 512x1x16
|
||||||
|
|
||||||
self.cnn = cnn
|
self.cnn = cnn
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ def test_base_recognizer():
|
|||||||
type='CTCConvertor', dict_file=dict_file, with_unknown=False)
|
type='CTCConvertor', dict_file=dict_file, with_unknown=False)
|
||||||
|
|
||||||
preprocessor = None
|
preprocessor = None
|
||||||
backbone = dict(type='VeryDeepVgg', leakyRelu=False)
|
backbone = dict(type='VeryDeepVgg', leaky_relu=False)
|
||||||
encoder = None
|
encoder = None
|
||||||
decoder = dict(type='CRNNDecoder', in_channels=512, rnn_flag=True)
|
decoder = dict(type='CRNNDecoder', in_channels=512, rnn_flag=True)
|
||||||
loss = dict(type='CTCLoss')
|
loss = dict(type='CTCLoss')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user