From e7f27ae31712850843ba8ba9348eb4793cc76a15 Mon Sep 17 00:00:00 2001 From: kang sheng Date: Tue, 11 May 2021 16:07:36 +0800 Subject: [PATCH] change argument names according to convention (#131) * change argument names according to convention * bug fix when rename leakyRelu --- configs/_base_/recog_models/crnn.py | 2 +- .../textrecog/crnn/crnn_academic_dataset.py | 2 +- .../textrecog/backbones/very_deep_vgg.py | 32 +++++++++---------- tests/test_models/test_recognizer.py | 2 +- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/configs/_base_/recog_models/crnn.py b/configs/_base_/recog_models/crnn.py index 6b98c3d9..4bcfa5c8 100644 --- a/configs/_base_/recog_models/crnn.py +++ b/configs/_base_/recog_models/crnn.py @@ -4,7 +4,7 @@ label_convertor = dict( model = dict( type='CRNNNet', preprocessor=None, - backbone=dict(type='VeryDeepVgg', leakyRelu=False), + backbone=dict(type='VeryDeepVgg', leaky_relu=False), encoder=None, decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True), loss=dict(type='CTCLoss', flatten=False), diff --git a/configs/textrecog/crnn/crnn_academic_dataset.py b/configs/textrecog/crnn/crnn_academic_dataset.py index 045ff999..50bf3f3f 100644 --- a/configs/textrecog/crnn/crnn_academic_dataset.py +++ b/configs/textrecog/crnn/crnn_academic_dataset.py @@ -21,7 +21,7 @@ label_convertor = dict( model = dict( type='CRNNNet', preprocessor=None, - backbone=dict(type='VeryDeepVgg', leakyRelu=False, input_channels=1), + backbone=dict(type='VeryDeepVgg', leaky_relu=False, input_channels=1), encoder=None, decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True), loss=dict(type='CTCLoss'), diff --git a/mmocr/models/textrecog/backbones/very_deep_vgg.py b/mmocr/models/textrecog/backbones/very_deep_vgg.py index a30ace5c..277f89ca 100644 --- a/mmocr/models/textrecog/backbones/very_deep_vgg.py +++ b/mmocr/models/textrecog/backbones/very_deep_vgg.py @@ -9,11 +9,11 @@ class VeryDeepVgg(nn.Module): """Implement VGG-VeryDeep backbone for text recognition, modified from `VGG-VeryDeep `_ Args: + leaky_relu (bool): Use leakyRelu or not. input_channels (int): Number of channels of input image tensor. - leakyRelu (bool): Use leakyRelu or not. """ - def __init__(self, leakyRelu=True, input_channels=3): + def __init__(self, leaky_relu=True, input_channels=3): super().__init__() ks = [3, 3, 3, 3, 3, 3, 2] @@ -25,32 +25,32 @@ class VeryDeepVgg(nn.Module): cnn = nn.Sequential() - def convRelu(i, batchNormalization=False): - nIn = input_channels if i == 0 else nm[i - 1] - nOut = nm[i] + def conv_relu(i, batch_normalization=False): + n_in = input_channels if i == 0 else nm[i - 1] + n_out = nm[i] cnn.add_module('conv{0}'.format(i), - nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i])) - if batchNormalization: - cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut)) - if leakyRelu: + nn.Conv2d(n_in, n_out, ks[i], ss[i], ps[i])) + if batch_normalization: + cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(n_out)) + if leaky_relu: cnn.add_module('relu{0}'.format(i), nn.LeakyReLU(0.2, inplace=True)) else: cnn.add_module('relu{0}'.format(i), nn.ReLU(True)) - convRelu(0) + conv_relu(0) cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64 - convRelu(1) + conv_relu(1) cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32 - convRelu(2, True) - convRelu(3) + conv_relu(2, True) + conv_relu(3) cnn.add_module('pooling{0}'.format(2), nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 - convRelu(4, True) - convRelu(5) + conv_relu(4, True) + conv_relu(5) cnn.add_module('pooling{0}'.format(3), nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16 - convRelu(6, True) # 512x1x16 + conv_relu(6, True) # 512x1x16 self.cnn = cnn diff --git a/tests/test_models/test_recognizer.py b/tests/test_models/test_recognizer.py index d1081406..495f360e 100644 --- a/tests/test_models/test_recognizer.py +++ b/tests/test_models/test_recognizer.py @@ -27,7 +27,7 @@ def test_base_recognizer(): type='CTCConvertor', dict_file=dict_file, with_unknown=False) preprocessor = None - backbone = dict(type='VeryDeepVgg', leakyRelu=False) + backbone = dict(type='VeryDeepVgg', leaky_relu=False) encoder = None decoder = dict(type='CRNNDecoder', in_channels=512, rnn_flag=True) loss = dict(type='CTCLoss')