fix PREN export and infer ()

pull/7845/head
littletomatodonkey 2022-10-08 16:37:12 +08:00 committed by GitHub
parent 077196f3cb
commit eeef62b3c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 174 additions and 120 deletions

View File

@ -21,124 +21,165 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import math import math
from collections import namedtuple import re
import collections
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
__all__ = ['EfficientNetb3'] __all__ = ['EfficientNetb3']
GlobalParams = collections.namedtuple('GlobalParams', [
'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'num_classes',
'width_coefficient', 'depth_coefficient', 'depth_divisor', 'min_depth',
'drop_connect_rate', 'image_size'
])
class EffB3Params: BlockArgs = collections.namedtuple('BlockArgs', [
@staticmethod
def get_global_params():
"""
The fllowing are efficientnetb3's arch superparams, but to fit for scene
text recognition task, the resolution(image_size) here is changed
from 300 to 64.
"""
GlobalParams = namedtuple('GlobalParams', [
'drop_connect_rate', 'width_coefficient', 'depth_coefficient',
'depth_divisor', 'image_size'
])
global_params = GlobalParams(
drop_connect_rate=0.3,
width_coefficient=1.2,
depth_coefficient=1.4,
depth_divisor=8,
image_size=64)
return global_params
@staticmethod
def get_block_params():
BlockParams = namedtuple('BlockParams', [
'kernel_size', 'num_repeat', 'input_filters', 'output_filters', 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
'expand_ratio', 'id_skip', 'se_ratio', 'stride' 'expand_ratio', 'id_skip', 'stride', 'se_ratio'
]) ])
block_params = [
BlockParams(3, 1, 32, 16, 1, True, 0.25, 1),
BlockParams(3, 2, 16, 24, 6, True, 0.25, 2), class BlockDecoder:
BlockParams(5, 2, 24, 40, 6, True, 0.25, 2), @staticmethod
BlockParams(3, 3, 40, 80, 6, True, 0.25, 2), def _decode_block_string(block_string):
BlockParams(5, 3, 80, 112, 6, True, 0.25, 1), assert isinstance(block_string, str)
BlockParams(5, 4, 112, 192, 6, True, 0.25, 2),
BlockParams(3, 1, 192, 320, 6, True, 0.25, 1) ops = block_string.split('_')
options = {}
for op in ops:
splits = re.split(r'(\d.*)', op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
assert (('s' in options and len(options['s']) == 1) or
(len(options['s']) == 2 and options['s'][0] == options['s'][1]))
return BlockArgs(
kernel_size=int(options['k']),
num_repeat=int(options['r']),
input_filters=int(options['i']),
output_filters=int(options['o']),
expand_ratio=int(options['e']),
id_skip=('noskip' not in block_string),
se_ratio=float(options['se']) if 'se' in options else None,
stride=[int(options['s'][0])])
@staticmethod
def decode(string_list):
assert isinstance(string_list, list)
blocks_args = []
for block_string in string_list:
blocks_args.append(BlockDecoder._decode_block_string(block_string))
return blocks_args
def efficientnet(width_coefficient=None,
depth_coefficient=None,
dropout_rate=0.2,
drop_connect_rate=0.2,
image_size=None,
num_classes=1000):
blocks_args = [
'r1_k3_s11_e1_i32_o16_se0.25',
'r2_k3_s22_e6_i16_o24_se0.25',
'r2_k5_s22_e6_i24_o40_se0.25',
'r3_k3_s22_e6_i40_o80_se0.25',
'r3_k5_s11_e6_i80_o112_se0.25',
'r4_k5_s22_e6_i112_o192_se0.25',
'r1_k3_s11_e6_i192_o320_se0.25',
] ]
return block_params blocks_args = BlockDecoder.decode(blocks_args)
global_params = GlobalParams(
batch_norm_momentum=0.99,
batch_norm_epsilon=1e-3,
dropout_rate=dropout_rate,
drop_connect_rate=drop_connect_rate,
num_classes=num_classes,
width_coefficient=width_coefficient,
depth_coefficient=depth_coefficient,
depth_divisor=8,
min_depth=None,
image_size=image_size, )
return blocks_args, global_params
class EffUtils: class EffUtils:
@staticmethod @staticmethod
def round_filters(filters, global_params): def round_filters(filters, global_params):
"""Calculate and round number of filters based on depth multiplier.""" """ Calculate and round number of filters based on depth multiplier. """
multiplier = global_params.width_coefficient multiplier = global_params.width_coefficient
if not multiplier: if not multiplier:
return filters return filters
divisor = global_params.depth_divisor divisor = global_params.depth_divisor
min_depth = global_params.min_depth
filters *= multiplier filters *= multiplier
new_filters = int(filters + divisor / 2) // divisor * divisor min_depth = min_depth or divisor
new_filters = max(min_depth,
int(filters + divisor / 2) // divisor * divisor)
if new_filters < 0.9 * filters: if new_filters < 0.9 * filters:
new_filters += divisor new_filters += divisor
return int(new_filters) return int(new_filters)
@staticmethod @staticmethod
def round_repeats(repeats, global_params): def round_repeats(repeats, global_params):
"""Round number of filters based on depth multiplier.""" """ Round number of filters based on depth multiplier. """
multiplier = global_params.depth_coefficient multiplier = global_params.depth_coefficient
if not multiplier: if not multiplier:
return repeats return repeats
return int(math.ceil(multiplier * repeats)) return int(math.ceil(multiplier * repeats))
class ConvBlock(nn.Layer): class MbConvBlock(nn.Layer):
def __init__(self, block_params): def __init__(self, block_args):
super(ConvBlock, self).__init__() super(MbConvBlock, self).__init__()
self.block_args = block_params self._block_args = block_args
self.has_se = (self.block_args.se_ratio is not None) and \ self.has_se = (self._block_args.se_ratio is not None) and \
(0 < self.block_args.se_ratio <= 1) (0 < self._block_args.se_ratio <= 1)
self.id_skip = block_params.id_skip self.id_skip = block_args.id_skip
# expansion phase # expansion phase
self.input_filters = self.block_args.input_filters self.inp = self._block_args.input_filters
output_filters = \ oup = self._block_args.input_filters * self._block_args.expand_ratio
self.block_args.input_filters * self.block_args.expand_ratio if self._block_args.expand_ratio != 1:
if self.block_args.expand_ratio != 1: self._expand_conv = nn.Conv2D(self.inp, oup, 1, bias_attr=False)
self.expand_conv = nn.Conv2D( self._bn0 = nn.BatchNorm(oup)
self.input_filters, output_filters, 1, bias_attr=False)
self.bn0 = nn.BatchNorm(output_filters)
# depthwise conv phase # depthwise conv phase
k = self.block_args.kernel_size k = self._block_args.kernel_size
s = self.block_args.stride s = self._block_args.stride
self.depthwise_conv = nn.Conv2D( if isinstance(s, list):
output_filters, s = s[0]
output_filters, self._depthwise_conv = nn.Conv2D(
groups=output_filters, oup,
oup,
groups=oup,
kernel_size=k, kernel_size=k,
stride=s, stride=s,
padding='same', padding='same',
bias_attr=False) bias_attr=False)
self.bn1 = nn.BatchNorm(output_filters) self._bn1 = nn.BatchNorm(oup)
# squeeze and excitation layer, if desired # squeeze and excitation layer, if desired
if self.has_se: if self.has_se:
num_squeezed_channels = max(1, num_squeezed_channels = max(1,
int(self.block_args.input_filters * int(self._block_args.input_filters *
self.block_args.se_ratio)) self._block_args.se_ratio))
self.se_reduce = nn.Conv2D(output_filters, num_squeezed_channels, 1) self._se_reduce = nn.Conv2D(oup, num_squeezed_channels, 1)
self.se_expand = nn.Conv2D(num_squeezed_channels, output_filters, 1) self._se_expand = nn.Conv2D(num_squeezed_channels, oup, 1)
# output phase # output phase and some util class
self.final_oup = self.block_args.output_filters self.final_oup = self._block_args.output_filters
self.project_conv = nn.Conv2D( self._project_conv = nn.Conv2D(oup, self.final_oup, 1, bias_attr=False)
output_filters, self.final_oup, 1, bias_attr=False) self._bn2 = nn.BatchNorm(self.final_oup)
self.bn2 = nn.BatchNorm(self.final_oup) self._swish = nn.Swish()
self.swish = nn.Swish()
def drop_connect(self, inputs, p, training): def _drop_connect(self, inputs, p, training):
if not training: if not training:
return inputs return inputs
batch_size = inputs.shape[0] batch_size = inputs.shape[0]
keep_prob = 1 - p keep_prob = 1 - p
random_tensor = keep_prob random_tensor = keep_prob
@ -151,22 +192,23 @@ class ConvBlock(nn.Layer):
def forward(self, inputs, drop_connect_rate=None): def forward(self, inputs, drop_connect_rate=None):
# expansion and depthwise conv # expansion and depthwise conv
x = inputs x = inputs
if self.block_args.expand_ratio != 1: if self._block_args.expand_ratio != 1:
x = self.swish(self.bn0(self.expand_conv(inputs))) x = self._swish(self._bn0(self._expand_conv(inputs)))
x = self.swish(self.bn1(self.depthwise_conv(x))) x = self._swish(self._bn1(self._depthwise_conv(x)))
# squeeze and excitation # squeeze and excitation
if self.has_se: if self.has_se:
x_squeezed = F.adaptive_avg_pool2d(x, 1) x_squeezed = F.adaptive_avg_pool2d(x, 1)
x_squeezed = self.se_expand(self.swish(self.se_reduce(x_squeezed))) x_squeezed = self._se_expand(
self._swish(self._se_reduce(x_squeezed)))
x = F.sigmoid(x_squeezed) * x x = F.sigmoid(x_squeezed) * x
x = self.bn2(self.project_conv(x)) x = self._bn2(self._project_conv(x))
# skip conntection and drop connect # skip conntection and drop connect
if self.id_skip and self.block_args.stride == 1 and \ if self.id_skip and self._block_args.stride == 1 and \
self.input_filters == self.final_oup: self.inp == self.final_oup:
if drop_connect_rate: if drop_connect_rate:
x = self.drop_connect( x = self._drop_connect(
x, p=drop_connect_rate, training=self.training) x, p=drop_connect_rate, training=self.training)
x = x + inputs x = x + inputs
return x return x
@ -175,54 +217,63 @@ class ConvBlock(nn.Layer):
class EfficientNetb3_PREN(nn.Layer): class EfficientNetb3_PREN(nn.Layer):
def __init__(self, in_channels): def __init__(self, in_channels):
super(EfficientNetb3_PREN, self).__init__() super(EfficientNetb3_PREN, self).__init__()
self.blocks_params = EffB3Params.get_block_params() """
self.global_params = EffB3Params.get_global_params() the fllowing are efficientnetb3's superparams,
they means efficientnetb3 network's width, depth, resolution and
dropout respectively, to fit for text recognition task, the resolution
here is changed from 300 to 64.
"""
w, d, s, p = 1.2, 1.4, 64, 0.3
self._blocks_args, self._global_params = efficientnet(
width_coefficient=w,
depth_coefficient=d,
dropout_rate=p,
image_size=s)
self.out_channels = [] self.out_channels = []
# stem # stem
stem_channels = EffUtils.round_filters(32, self.global_params) out_channels = EffUtils.round_filters(32, self._global_params)
self.conv_stem = nn.Conv2D( self._conv_stem = nn.Conv2D(
in_channels, stem_channels, 3, 2, padding='same', bias_attr=False) in_channels, out_channels, 3, 2, padding='same', bias_attr=False)
self.bn0 = nn.BatchNorm(stem_channels) self._bn0 = nn.BatchNorm(out_channels)
self.blocks = [] # build blocks
self._blocks = []
# to extract three feature maps for fpn based on efficientnetb3 backbone # to extract three feature maps for fpn based on efficientnetb3 backbone
self.concerned_block_idxes = [7, 17, 25] self._concerned_block_idxes = [7, 17, 25]
concerned_idx = 0 _concerned_idx = 0
for i, block_params in enumerate(self.blocks_params): for i, block_args in enumerate(self._blocks_args):
block_params = block_params._replace( block_args = block_args._replace(
input_filters=EffUtils.round_filters(block_params.input_filters, input_filters=EffUtils.round_filters(block_args.input_filters,
self.global_params), self._global_params),
output_filters=EffUtils.round_filters( output_filters=EffUtils.round_filters(block_args.output_filters,
block_params.output_filters, self.global_params), self._global_params),
num_repeat=EffUtils.round_repeats(block_params.num_repeat, num_repeat=EffUtils.round_repeats(block_args.num_repeat,
self.global_params)) self._global_params))
self.blocks.append( self._blocks.append(
self.add_sublayer("{}-0".format(i), ConvBlock(block_params))) self.add_sublayer(f"{i}-0", MbConvBlock(block_args)))
concerned_idx += 1 _concerned_idx += 1
if concerned_idx in self.concerned_block_idxes: if _concerned_idx in self._concerned_block_idxes:
self.out_channels.append(block_params.output_filters) self.out_channels.append(block_args.output_filters)
if block_params.num_repeat > 1: if block_args.num_repeat > 1:
block_params = block_params._replace( block_args = block_args._replace(
input_filters=block_params.output_filters, stride=1) input_filters=block_args.output_filters, stride=1)
for j in range(block_params.num_repeat - 1): for j in range(block_args.num_repeat - 1):
self.blocks.append( self._blocks.append(
self.add_sublayer('{}-{}'.format(i, j + 1), self.add_sublayer(f'{i}-{j+1}', MbConvBlock(block_args)))
ConvBlock(block_params))) _concerned_idx += 1
concerned_idx += 1 if _concerned_idx in self._concerned_block_idxes:
if concerned_idx in self.concerned_block_idxes: self.out_channels.append(block_args.output_filters)
self.out_channels.append(block_params.output_filters)
self.swish = nn.Swish() self._swish = nn.Swish()
def forward(self, inputs): def forward(self, inputs):
outs = [] outs = []
x = self._swish(self._bn0(self._conv_stem(inputs)))
x = self.swish(self.bn0(self.conv_stem(inputs))) for idx, block in enumerate(self._blocks):
for idx, block in enumerate(self.blocks): drop_connect_rate = self._global_params.drop_connect_rate
drop_connect_rate = self.global_params.drop_connect_rate
if drop_connect_rate: if drop_connect_rate:
drop_connect_rate *= float(idx) / len(self.blocks) drop_connect_rate *= float(idx) / len(self._blocks)
x = block(x, drop_connect_rate=drop_connect_rate) x = block(x, drop_connect_rate=drop_connect_rate)
if idx in self.concerned_block_idxes: if idx in self._concerned_block_idxes:
outs.append(x) outs.append(x)
return outs return outs

View File

@ -562,6 +562,7 @@ class PRENLabelDecode(BaseRecLabelDecode):
return result_list return result_list
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy() preds = preds.numpy()
preds_idx = preds.argmax(axis=2) preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2) preds_prob = preds.max(axis=2)

View File

@ -77,7 +77,7 @@ def export_single_model(model,
elif arch_config["algorithm"] == "PREN": elif arch_config["algorithm"] == "PREN":
other_shape = [ other_shape = [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 3, 64, 512], dtype="float32"), shape=[None, 3, 64, 256], dtype="float32"),
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["model_type"] == "sr": elif arch_config["model_type"] == "sr":

View File

@ -100,6 +100,8 @@ class TextRecognizer(object):
"use_space_char": args.use_space_char, "use_space_char": args.use_space_char,
"rm_symbol": True "rm_symbol": True
} }
elif self.rec_algorithm == "PREN":
postprocess_params = {'name': 'PRENLabelDecode'}
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger) utility.create_predictor(args, 'rec', logger)
@ -384,7 +386,7 @@ class TextRecognizer(object):
self.rec_image_shape) self.rec_image_shape)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img) norm_img_batch.append(norm_img)
elif self.rec_algorithm == "VisionLAN": elif self.rec_algorithm in ["VisionLAN", "PREN"]:
norm_img = self.resize_norm_img_vl(img_list[indices[ino]], norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
self.rec_image_shape) self.rec_image_shape)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]