change argument names according to convention (#131)

* change argument names according to convention

* bug fix when rename leakyRelu
pull/166/head
kang sheng 2021-05-11 16:07:36 +08:00 committed by GitHub
parent 47896a3f80
commit e7f27ae317
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 19 deletions

View File

@ -4,7 +4,7 @@ label_convertor = dict(
model = dict(
type='CRNNNet',
preprocessor=None,
backbone=dict(type='VeryDeepVgg', leakyRelu=False),
backbone=dict(type='VeryDeepVgg', leaky_relu=False),
encoder=None,
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
loss=dict(type='CTCLoss', flatten=False),

View File

@ -21,7 +21,7 @@ label_convertor = dict(
model = dict(
type='CRNNNet',
preprocessor=None,
backbone=dict(type='VeryDeepVgg', leakyRelu=False, input_channels=1),
backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1),
encoder=None,
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
loss=dict(type='CTCLoss'),

View File

@ -9,11 +9,11 @@ class VeryDeepVgg(nn.Module):
"""Implement VGG-VeryDeep backbone for text recognition, modified from
`VGG-VeryDeep <https://arxiv.org/pdf/1409.1556.pdf>`_
Args:
leaky_relu (bool): Use leakyRelu or not.
input_channels (int): Number of channels of input image tensor.
leakyRelu (bool): Use leakyRelu or not.
"""
def __init__(self, leakyRelu=True, input_channels=3):
def __init__(self, leaky_relu=True, input_channels=3):
super().__init__()
ks = [3, 3, 3, 3, 3, 3, 2]
@ -25,32 +25,32 @@ class VeryDeepVgg(nn.Module):
cnn = nn.Sequential()
def convRelu(i, batchNormalization=False):
nIn = input_channels if i == 0 else nm[i - 1]
nOut = nm[i]
def conv_relu(i, batch_normalization=False):
n_in = input_channels if i == 0 else nm[i - 1]
n_out = nm[i]
cnn.add_module('conv{0}'.format(i),
nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
if batchNormalization:
cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
if leakyRelu:
nn.Conv2d(n_in, n_out, ks[i], ss[i], ps[i]))
if batch_normalization:
cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(n_out))
if leaky_relu:
cnn.add_module('relu{0}'.format(i),
nn.LeakyReLU(0.2, inplace=True))
else:
cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
convRelu(0)
conv_relu(0)
cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64
convRelu(1)
conv_relu(1)
cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32
convRelu(2, True)
convRelu(3)
conv_relu(2, True)
conv_relu(3)
cnn.add_module('pooling{0}'.format(2),
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
convRelu(4, True)
convRelu(5)
conv_relu(4, True)
conv_relu(5)
cnn.add_module('pooling{0}'.format(3),
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
convRelu(6, True) # 512x1x16
conv_relu(6, True) # 512x1x16
self.cnn = cnn

View File

@ -27,7 +27,7 @@ def test_base_recognizer():
type='CTCConvertor', dict_file=dict_file, with_unknown=False)
preprocessor = None
backbone = dict(type='VeryDeepVgg', leakyRelu=False)
backbone = dict(type='VeryDeepVgg', leaky_relu=False)
encoder = None
decoder = dict(type='CRNNDecoder', in_channels=512, rnn_flag=True)
loss = dict(type='CTCLoss')