370 lines
11 KiB
Python
370 lines
11 KiB
Python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
This code is adapted from:
|
|
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/backbones/table_resnet_extra.py
|
|
"""
|
|
|
|
import paddle
|
|
import paddle.nn as nn
|
|
import paddle.nn.functional as F
|
|
|
|
|
|
class BasicBlock(nn.Layer):
    """Residual unit of two 3x3 convolutions with optional GC attention.

    Args:
        inplanes: channel count of the input tensor.
        planes: channel count produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional layer applied to the input to form the shortcut;
            when None the input itself is added, so shapes must already match.
        gcb_config: optional dict with keys ``'ratio'``, ``'headers'``,
            ``'att_scale'`` and ``'fusion_type'``; when given, a
            ``MultiAspectGCAttention`` is applied to the residual branch
            before the shortcut is added.
    """

    # Output channels are ``planes * expansion``.
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 gcb_config=None):
        super(BasicBlock, self).__init__()
        # Residual branch: conv-bn-relu, then conv-bn (the last relu is
        # applied after the shortcut addition in forward()).
        self.conv1 = nn.Conv2D(
            inplanes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias_attr=False)
        self.bn1 = nn.BatchNorm2D(planes, momentum=0.9)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias_attr=False)
        self.bn2 = nn.BatchNorm2D(planes, momentum=0.9)
        self.downsample = downsample
        self.stride = stride
        self.gcb_config = gcb_config

        if gcb_config is not None:
            # Global-context attention over the branch output (planes ch).
            self.context_block = MultiAspectGCAttention(
                inplanes=planes,
                ratio=gcb_config['ratio'],
                headers=gcb_config['headers'],
                att_scale=gcb_config['att_scale'],
                fusion_type=gcb_config['fusion_type'])

    def forward(self, x):
        """Return ``relu(branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        if self.gcb_config is not None:
            branch = self.context_block(branch)

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(branch + shortcut)
|
|
|
|
|
|
def get_gcb_config(gcb_config, layer):
    """Return ``gcb_config`` when GC attention is enabled for stage ``layer``.

    Args:
        gcb_config: None, or a dict whose ``'layers'`` entry is indexable by
            stage number and holds a truthy value for enabled stages.
        layer: stage index looked up in ``gcb_config['layers']``.

    Returns:
        The same ``gcb_config`` object if it is not None and
        ``gcb_config['layers'][layer]`` is truthy, otherwise None.
    """
    enabled = gcb_config is not None and gcb_config['layers'][layer]
    return gcb_config if enabled else None
|
|
|
|
|
|
class TableResNetExtra(nn.Layer):
    """ResNet-style backbone that returns three feature maps.

    Args:
        layers: sequence of at least four ints; ``layers[i]`` is the number
            of ``BasicBlock`` units in residual stage ``i``.
        in_channels: channel count of the input tensor.
        gcb_config: optional dict passed through ``get_gcb_config``; when not
            None, its ``'layers'`` entry selects which of the four stages get
            a ``MultiAspectGCAttention`` in their first block.
    """

    def __init__(self, layers, in_channels=3, gcb_config=None):
        assert len(layers) >= 4

        super(TableResNetExtra, self).__init__()
        # Channels entering the first residual stage (output of conv2);
        # _make_layer advances it as each stage is built.
        self.inplanes = 128

        def conv3x3(cin, cout):
            # 3x3 kernel, stride 1, padding 1: height and width unchanged.
            return nn.Conv2D(
                cin, cout, kernel_size=3, stride=1, padding=1, bias_attr=False)

        # Stem: conv-bn-relu twice (in_channels -> 64 -> 128), then pool.
        self.conv1 = conv3x3(in_channels, 64)
        self.bn1 = nn.BatchNorm2D(64)
        self.relu1 = nn.ReLU()
        self.conv2 = conv3x3(64, 128)
        self.bn2 = nn.BatchNorm2D(128)
        self.relu2 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2)

        # Stage 1 (128 -> 256 channels) + conv3 unit, then pool.
        self.layer1 = self._make_layer(
            BasicBlock, 256, layers[0], stride=1,
            gcb_config=get_gcb_config(gcb_config, 0))
        self.conv3 = conv3x3(256, 256)
        self.bn3 = nn.BatchNorm2D(256)
        self.relu3 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2)

        # Stage 2 (256 channels) + conv4 unit, then pool.
        self.layer2 = self._make_layer(
            BasicBlock, 256, layers[1], stride=1,
            gcb_config=get_gcb_config(gcb_config, 1))
        self.conv4 = conv3x3(256, 256)
        self.bn4 = nn.BatchNorm2D(256)
        self.relu4 = nn.ReLU()
        self.maxpool3 = nn.MaxPool2D(kernel_size=2, stride=2)

        # Stage 3 (256 -> 512 channels) + conv5 unit.
        self.layer3 = self._make_layer(
            BasicBlock, 512, layers[2], stride=1,
            gcb_config=get_gcb_config(gcb_config, 2))
        self.conv5 = conv3x3(512, 512)
        self.bn5 = nn.BatchNorm2D(512)
        self.relu5 = nn.ReLU()

        # Stage 4 (512 channels, no pooling before it) + conv6 unit.
        self.layer4 = self._make_layer(
            BasicBlock, 512, layers[3], stride=1,
            gcb_config=get_gcb_config(gcb_config, 3))
        self.conv6 = conv3x3(512, 512)
        self.bn6 = nn.BatchNorm2D(512)
        self.relu6 = nn.ReLU()

        # Channel counts of the three maps returned by forward().
        self.out_channels = [256, 256, 512]

    def _make_layer(self, block, planes, blocks, stride=1, gcb_config=None):
        """Build one residual stage as an ``nn.Sequential``.

        Only the first block receives ``stride``, the 1x1 conv + BN shortcut
        projection (created when stride or channel count changes) and
        ``gcb_config``; the remaining ``blocks - 1`` blocks use defaults.
        At least one block is always built. Updates ``self.inplanes``.
        """
        expanded = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != expanded:
            downsample = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    expanded,
                    kernel_size=1,
                    stride=stride,
                    bias_attr=False),
                nn.BatchNorm2D(expanded))

        head = block(
            self.inplanes, planes, stride, downsample, gcb_config=gcb_config)
        self.inplanes = expanded
        tail = [block(self.inplanes, planes) for _ in range(1, blocks)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        """Run the backbone and return ``[f1, f2, f3]``.

        The maps are taken after relu3 (256 ch), relu4 (256 ch) and relu6
        (512 ch). Each preceding 2x2/stride-2 max-pool halves (floor) height
        and width, so they sit at 1/2, 1/4 and 1/8 of the input resolution.
        """
        features = []

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))

        x = self.layer1(self.maxpool1(x))
        x = self.relu3(self.bn3(self.conv3(x)))
        features.append(x)

        x = self.layer2(self.maxpool2(x))
        x = self.relu4(self.bn4(self.conv4(x)))
        features.append(x)

        x = self.layer3(self.maxpool3(x))
        x = self.relu5(self.bn5(self.conv5(x)))

        x = self.layer4(x)
        x = self.relu6(self.bn6(self.conv6(x)))
        features.append(x)

        return features
|
|
|
|
|
|
class MultiAspectGCAttention(nn.Layer):
    """Multi-aspect global context (GC) attention block.

    Pools a global context vector from the input (per-head attention pooling
    or plain average pooling), passes it through a 1x1-conv bottleneck and
    fuses it back into the input by channel-wise add, multiply or concat.
    The output has the same shape as the input.

    Args:
        inplanes: number of input (and output) channels; must be >= 8 and
            divisible by ``headers``.
        ratio: bottleneck ratio; hidden width is ``int(inplanes * ratio)``.
        headers: number of heads the channels are split into for
            attention pooling.
        pooling_type: ``'att'`` (attention pooling) or ``'avg'``.
        att_scale: if True and ``headers > 1``, attention logits are divided
            by ``sqrt(inplanes / headers)`` before the softmax.
        fusion_type: ``'channel_add'``, ``'channel_mul'`` or
            ``'channel_concat'``.
    """

    def __init__(self,
                 inplanes,
                 ratio,
                 headers,
                 pooling_type='att',
                 att_scale=False,
                 fusion_type='channel_add'):
        super(MultiAspectGCAttention, self).__init__()
        assert pooling_type in ['avg', 'att']

        assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat']
        # inplanes must be divisible by headers so every head gets the same
        # number of channels.
        assert inplanes % headers == 0 and inplanes >= 8

        self.headers = headers
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_type = fusion_type
        # Honour the caller's choice (it was previously forced to False,
        # which made the scaling step in spatial_pool unreachable).
        self.att_scale = att_scale

        # Channels per head (C' in the shape comments below).
        self.single_header_inplanes = inplanes // headers

        if pooling_type == 'att':
            # One attention logit per spatial position, per head.
            self.conv_mask = nn.Conv2D(
                self.single_header_inplanes, 1, kernel_size=1)
            # Axis 2 is the flattened H*W axis of [N*headers, 1, H*W].
            self.softmax = nn.Softmax(axis=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2D(1)

        if fusion_type == 'channel_add':
            self.channel_add_conv = self._make_context_transform()
        elif fusion_type == 'channel_concat':
            self.channel_concat_conv = self._make_context_transform()
            # Reduces [x, broadcast context] (2*C channels) back to C.
            self.cat_conv = nn.Conv2D(
                2 * self.inplanes, self.inplanes, kernel_size=1)
        elif fusion_type == 'channel_mul':
            self.channel_mul_conv = self._make_context_transform()

    def _make_context_transform(self):
        """Build the C -> planes -> C 1x1-conv bottleneck for the context."""
        return nn.Sequential(
            nn.Conv2D(
                self.inplanes, self.planes, kernel_size=1),
            nn.LayerNorm([self.planes, 1, 1]),
            nn.ReLU(),
            nn.Conv2D(
                self.planes, self.inplanes, kernel_size=1))

    def spatial_pool(self, x):
        """Pool ``x`` of shape [N, C, H, W] into a context of [N, C, 1, 1]."""
        batch, _, height, width = x.shape
        if self.pooling_type == 'att':
            # [N*headers, C', H, W]   (C = headers * C')
            x = x.reshape([
                batch * self.headers, self.single_header_inplanes, height,
                width
            ])
            # [N*headers, C', H*W] -> [N*headers, 1, C', H*W]
            input_x = x.reshape([
                batch * self.headers, self.single_header_inplanes,
                height * width
            ]).unsqueeze(1)

            # [N*headers, 1, H, W] -> [N*headers, 1, H*W]
            context_mask = self.conv_mask(x).reshape(
                [batch * self.headers, 1, height * width])

            if self.att_scale and self.headers > 1:
                # Scaled attention: divide logits by sqrt(C'). A Python float
                # is used because paddle.sqrt expects a Tensor, not an int.
                context_mask = context_mask / (
                    self.single_header_inplanes**0.5)

            # Softmax over spatial positions -> [N*headers, 1, H*W, 1]
            context_mask = self.softmax(context_mask).unsqueeze(-1)

            # [N*headers, 1, C', 1] =
            #     [N*headers, 1, C', H*W] @ [N*headers, 1, H*W, 1]
            context = paddle.matmul(input_x, context_mask)

            # [N, headers * C', 1, 1]
            context = context.reshape(
                [batch, self.headers * self.single_header_inplanes, 1, 1])
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context

    def forward(self, x):
        """Fuse the pooled global context back into ``x`` (shape preserved)."""
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        if self.fusion_type == 'channel_mul':
            # Broadcast [N, C, 1, 1] gate over H and W.
            channel_mul_term = F.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        elif self.fusion_type == 'channel_add':
            # Broadcast [N, C, 1, 1] bias over H and W.
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term
        else:
            # [N, C, 1, 1]
            channel_concat_term = self.channel_concat_conv(context)

            _, _, H, W = out.shape
            # Tile the context over H, W and concatenate on channels.
            out = paddle.concat(
                [out, channel_concat_term.expand([-1, -1, H, W])], axis=1)
            out = self.cat_conv(out)
            out = F.layer_norm(out, [self.inplanes, H, W])
            out = F.relu(out)

        return out
|