From ff0f23d49565f44af57311a72ef387aa1aa1d17e Mon Sep 17 00:00:00 2001
From: WenmuZhou <zjwenmu@gmail.com>
Date: Tue, 10 Nov 2020 12:45:25 +0800
Subject: [PATCH 1/5] =?UTF-8?q?=E8=AE=AD=E7=BB=83=E9=9B=86shuffle=E6=94=B9?=
 =?UTF-8?q?=E4=B8=BAfalse?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/rec/rec_mv3_none_bilstm_ctc.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml
index 57a7c049df..def7237514 100644
--- a/configs/rec/rec_mv3_none_bilstm_ctc.yml
+++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml
@@ -72,7 +72,7 @@ Train:
       - KeepKeys:
           keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
   loader:
-    shuffle: True
+    shuffle: False
     batch_size_per_card: 256
     drop_last: True
     num_workers: 8

From 2f9f258ff48a1084724d9684175b311e9030efdf Mon Sep 17 00:00:00 2001
From: WenmuZhou <zjwenmu@gmail.com>
Date: Tue, 10 Nov 2020 17:18:32 +0800
Subject: [PATCH 2/5] =?UTF-8?q?=E6=B7=BB=E5=8A=A0tps=E7=BD=91=E7=BB=9C?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/rec/rec_r34_vd_tps_bilstm_ctc.yml  | 100 +++++++
 ppocr/modeling/architectures/base_model.py |  11 +-
 ppocr/modeling/transform/__init__.py       |   4 +-
 ppocr/modeling/transform/tps.py            | 289 +++++++++++++++++++++
 4 files changed, 398 insertions(+), 6 deletions(-)
 create mode 100644 configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
 create mode 100644 ppocr/modeling/transform/tps.py

diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
new file mode 100644
index 0000000000..269f1e4117
--- /dev/null
+++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
@@ -0,0 +1,100 @@
+Global:
+  use_gpu: true
+  epoch_num: 72
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/rec/r34_vd_tps_bilstm_ctc/
+  save_epoch_step: 3
+  # evaluation is run every 5000 iterations after the 4000th iteration
+  eval_batch_step: [0, 2000]
+  # if pretrained_model is saved in static mode, load_static_weights must set to True
+  cal_metric_during_train: True
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_words/ch/word_1.jpg
+  # for data or label process
+  character_dict_path: 
+  character_type: en
+  max_text_length: 25
+  infer_mode: False
+  use_space_char: False
+
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    learning_rate: 0.0005
+  regularizer:
+    name: 'L2'
+    factor: 0
+
+Architecture:
+  model_type: rec
+  algorithm: CRNN
+  Transform:
+    name: TPS
+    num_fiducial: 20
+    loc_lr: 0.1
+    model_name: small
+  Backbone:
+    name: ResNet
+    layers: 34
+  Neck:
+    name: SequenceEncoder
+    encoder_type: rnn
+    hidden_size: 256
+  Head:
+    name: CTCHead
+    fc_decay: 0
+
+Loss:
+  name: CTCLoss
+
+PostProcess:
+  name: CTCLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+
+Train:
+  dataset:
+    name: LMDBDateSet
+    data_dir: ./train_data/data_lmdb_release/training/
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - CTCLabelEncode: # Class handling label
+      - RecResizeImg:
+          image_shape: [3, 32, 100]
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+  loader:
+    shuffle: True
+    batch_size_per_card: 256
+    drop_last: True
+    num_workers: 8
+
+Eval:
+  dataset:
+    name: LMDBDateSet
+    data_dir: ./train_data/data_lmdb_release/validation/
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - CTCLabelEncode: # Class handling label
+      - RecResizeImg:
+          image_shape: [3, 32, 100]
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 256
+    num_workers: 4
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index c111960473..0c4fe6503f 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -16,13 +16,14 @@ from __future__ import division
 from __future__ import print_function
 
 from paddle import nn
-
+from ppocr.modeling.transform import build_transform
 from ppocr.modeling.backbones import build_backbone
 from ppocr.modeling.necks import build_neck
 from ppocr.modeling.heads import build_head
 
 __all__ = ['BaseModel']
 
+
 class BaseModel(nn.Layer):
     def __init__(self, config):
         """
@@ -31,7 +32,7 @@ class BaseModel(nn.Layer):
             config (dict): the super parameters for module.
         """
         super(BaseModel, self).__init__()
-        
+
         in_channels = config.get('in_channels', 3)
         model_type = config['model_type']
         # build transfrom,
@@ -50,7 +51,7 @@ class BaseModel(nn.Layer):
         config["Backbone"]['in_channels'] = in_channels
         self.backbone = build_backbone(config["Backbone"], model_type)
         in_channels = self.backbone.out_channels
-        
+
         # build neck
         # for rec, neck can be cnn,rnn or reshape(None)
         # for det, neck can be FPN, BIFPN and so on.
@@ -62,7 +63,7 @@ class BaseModel(nn.Layer):
             config['Neck']['in_channels'] = in_channels
             self.neck = build_neck(config['Neck'])
             in_channels = self.neck.out_channels
-        
+
         # # build head, head is need for det, rec and cls
         config["Head"]['in_channels'] = in_channels
         self.head = build_head(config["Head"])
@@ -74,4 +75,4 @@ class BaseModel(nn.Layer):
         if self.use_neck:
             x = self.neck(x)
         x = self.head(x)
-        return x
\ No newline at end of file
+        return x
diff --git a/ppocr/modeling/transform/__init__.py b/ppocr/modeling/transform/__init__.py
index af3b3f8697..78eaecccc5 100755
--- a/ppocr/modeling/transform/__init__.py
+++ b/ppocr/modeling/transform/__init__.py
@@ -16,7 +16,9 @@ __all__ = ['build_transform']
 
 
 def build_transform(config):
-    support_dict = ['']
+    from .tps import TPS
+
+    support_dict = ['TPS']
 
     module_name = config.pop('name')
     assert module_name in support_dict, Exception(
diff --git a/ppocr/modeling/transform/tps.py b/ppocr/modeling/transform/tps.py
new file mode 100644
index 0000000000..f5b4f60b8a
--- /dev/null
+++ b/ppocr/modeling/transform/tps.py
@@ -0,0 +1,289 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import numpy as np
+
+
+class ConvBNLayer(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 groups=1,
+                 act=None,
+                 name=None):
+        super(ConvBNLayer, self).__init__()
+        self.conv = nn.Conv2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=(kernel_size - 1) // 2,
+            groups=groups,
+            weight_attr=ParamAttr(name=name + "_weights"),
+            bias_attr=False)
+        bn_name = "bn_" + name
+        self.bn = nn.BatchNorm(
+            out_channels,
+            act=act,
+            param_attr=ParamAttr(name=bn_name + '_scale'),
+            bias_attr=ParamAttr(bn_name + '_offset'),
+            moving_mean_name=bn_name + '_mean',
+            moving_variance_name=bn_name + '_variance')
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
+
+
+class LocalizationNetwork(nn.Layer):
+    def __init__(self, in_channels, num_fiducial, loc_lr, model_name):
+        super(LocalizationNetwork, self).__init__()
+        self.F = num_fiducial
+        F = num_fiducial
+        if model_name == "large":
+            num_filters_list = [64, 128, 256, 512]
+            fc_dim = 256
+        else:
+            num_filters_list = [16, 32, 64, 128]
+            fc_dim = 64
+
+        self.block_list = []
+        for fno in range(0, len(num_filters_list)):
+            num_filters = num_filters_list[fno]
+            name = "loc_conv%d" % fno
+            conv = self.add_sublayer(
+                name,
+                ConvBNLayer(
+                    in_channels=in_channels,
+                    out_channels=num_filters,
+                    kernel_size=3,
+                    act='relu',
+                    name=name))
+            self.block_list.append(conv)
+            if fno == len(num_filters_list) - 1:
+                pool = nn.AdaptiveAvgPool2D(1)
+            else:
+                pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+            in_channels = num_filters
+            self.block_list.append(pool)
+        name = "loc_fc1"
+        self.fc1 = nn.Linear(
+            in_channels,
+            fc_dim,
+            weight_attr=ParamAttr(
+                learning_rate=loc_lr, name=name + "_w"),
+            bias_attr=ParamAttr(name=name + '.b_0'),
+            name=name)
+
+        # Init fc2 in LocalizationNetwork
+        initial_bias = self.get_initial_fiducials()
+        initial_bias = initial_bias.reshape(-1)
+        name = "loc_fc2"
+        param_attr = ParamAttr(
+            learning_rate=loc_lr,
+            initializer=paddle.fluid.initializer.NumpyArrayInitializer(
+                np.zeros([fc_dim, F * 2])),
+            name=name + "_w")
+        bias_attr = ParamAttr(
+            learning_rate=loc_lr,
+            initializer=paddle.fluid.initializer.NumpyArrayInitializer(
+                initial_bias),
+            name=name + "_b")
+        self.fc2 = nn.Linear(
+            fc_dim,
+            F * 2,
+            weight_attr=param_attr,
+            bias_attr=bias_attr,
+            name=name)
+        self.out_channels = F * 2
+
+    def forward(self, x):
+        """
+           Estimating parameters of geometric transformation
+           Args:
+               image: input
+           Return:
+               batch_C_prime: the matrix of the geometric transformation
+        """
+        B = x.shape[0]
+        i = 0
+        for block in self.block_list:
+            x = block(x)
+        x = x.reshape([B, -1])
+        x = self.fc1(x)
+
+        x = F.relu(x)
+        x = self.fc2(x)
+        x = x.reshape(shape=[-1, self.F, 2])
+        return x
+
+    def get_initial_fiducials(self):
+        """ see RARE paper Fig. 6 (a) """
+        F = self.F
+        ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
+        ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
+        ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
+        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+        initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
+        return initial_bias
+
+
+class GridGenerator(nn.Layer):
+    def __init__(self, in_channels, num_fiducial):
+        super(GridGenerator, self).__init__()
+        self.eps = 1e-6
+        self.F = num_fiducial
+
+        name = "ex_fc"
+        initializer = nn.initializer.Constant(value=0.0)
+        param_attr = ParamAttr(
+            learning_rate=0.0, initializer=initializer, name=name + "_w")
+        bias_attr = ParamAttr(
+            learning_rate=0.0, initializer=initializer, name=name + "_b")
+        self.fc = nn.Linear(
+            in_channels,
+            6,
+            weight_attr=param_attr,
+            bias_attr=bias_attr,
+            name=name)
+
+    def forward(self, batch_C_prime, I_r_size):
+        """
+        Generate the grid for the grid_sampler.
+        Args:
+            batch_C_prime: the matrix of the geometric transformation
+            I_r_size: the shape of the input image
+        Return:
+            batch_P_prime: the grid for the grid_sampler
+        """
+        C = self.build_C()
+        P = self.build_P(I_r_size)
+        inv_delta_C = self.build_inv_delta_C(C).astype('float32')
+        P_hat = self.build_P_hat(C, P).astype('float32')
+
+        inv_delta_C_tensor = paddle.to_tensor(inv_delta_C)
+        inv_delta_C_tensor.stop_gradient = True
+        P_hat_tensor = paddle.to_tensor(P_hat)
+        P_hat_tensor.stop_gradient = True
+
+        batch_C_ex_part_tensor = self.get_expand_tensor(batch_C_prime)
+
+        batch_C_ex_part_tensor.stop_gradient = True
+
+        batch_C_prime_with_zeros = paddle.concat(
+            [batch_C_prime, batch_C_ex_part_tensor], axis=1)
+        batch_T = paddle.matmul(inv_delta_C_tensor, batch_C_prime_with_zeros)
+        batch_P_prime = paddle.matmul(P_hat_tensor, batch_T)
+        return batch_P_prime
+
+    def build_C(self):
+        """ Return coordinates of fiducial points in I_r; C """
+        F = self.F
+        ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
+        ctrl_pts_y_top = -1 * np.ones(int(F / 2))
+        ctrl_pts_y_bottom = np.ones(int(F / 2))
+        ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+        ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+        C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
+        return C  # F x 2
+
+    def build_P(self, I_r_size):
+        I_r_width, I_r_height = I_r_size
+        I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) \
+                     / I_r_width  # self.I_r_width
+        I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) \
+                     / I_r_height  # self.I_r_height
+        # P: self.I_r_width x self.I_r_height x 2
+        P = np.stack(np.meshgrid(I_r_grid_x, I_r_grid_y), axis=2)
+        # n (= self.I_r_width x self.I_r_height) x 2
+        return P.reshape([-1, 2])
+
+    def build_inv_delta_C(self, C):
+        """ Return inv_delta_C which is needed to calculate T """
+        F = self.F
+        hat_C = np.zeros((F, F), dtype=float)  # F x F
+        for i in range(0, F):
+            for j in range(i, F):
+                r = np.linalg.norm(C[i] - C[j])
+                hat_C[i, j] = r
+                hat_C[j, i] = r
+        np.fill_diagonal(hat_C, 1)
+        hat_C = (hat_C**2) * np.log(hat_C)
+        # print(C.shape, hat_C.shape)
+        delta_C = np.concatenate(  # F+3 x F+3
+            [
+                np.concatenate(
+                    [np.ones((F, 1)), C, hat_C], axis=1),  # F x F+3
+                np.concatenate(
+                    [np.zeros((2, 3)), np.transpose(C)], axis=1),  # 2 x F+3
+                np.concatenate(
+                    [np.zeros((1, 3)), np.ones((1, F))], axis=1)  # 1 x F+3
+            ],
+            axis=0)
+        inv_delta_C = np.linalg.inv(delta_C)
+        return inv_delta_C  # F+3 x F+3
+
+    def build_P_hat(self, C, P):
+        F = self.F
+        eps = self.eps
+        n = P.shape[0]  # n (= self.I_r_width x self.I_r_height)
+        # P_tile: n x 2 -> n x 1 x 2 -> n x F x 2
+        P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1))
+        C_tile = np.expand_dims(C, axis=0)  # 1 x F x 2
+        P_diff = P_tile - C_tile  # n x F x 2
+        # rbf_norm: n x F
+        rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False)
+        # rbf: n x F
+        rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + eps))
+        P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
+        return P_hat  # n x F+3
+
+    def get_expand_tensor(self, batch_C_prime):
+        B = batch_C_prime.shape[0]
+        batch_C_prime = batch_C_prime.reshape([B, -1])
+        batch_C_ex_part_tensor = self.fc(batch_C_prime)
+        batch_C_ex_part_tensor = batch_C_ex_part_tensor.reshape([-1, 3, 2])
+        return batch_C_ex_part_tensor
+
+
+class TPS(nn.Layer):
+    def __init__(self, in_channels, num_fiducial, loc_lr, model_name):
+        super(TPS, self).__init__()
+        self.loc_net = LocalizationNetwork(in_channels, num_fiducial, loc_lr,
+                                           model_name)
+        self.grid_generator = GridGenerator(self.loc_net.out_channels,
+                                            num_fiducial)
+        self.out_channels = in_channels
+
+    def forward(self, image):
+        image.stop_gradient = False
+        I_r_size = [image.shape[3], image.shape[2]]
+
+        batch_C_prime = self.loc_net(image)
+        batch_P_prime = self.grid_generator(batch_C_prime, I_r_size)
+        batch_P_prime = batch_P_prime.reshape(
+            [-1, image.shape[2], image.shape[3], 2])
+        batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
+        return batch_I_r

From 65d3dfc729b820ec04c71084a7573f20698f2a14 Mon Sep 17 00:00:00 2001
From: WenmuZhou <zjwenmu@gmail.com>
Date: Tue, 10 Nov 2020 17:18:50 +0800
Subject: [PATCH 3/5] =?UTF-8?q?rnn=E6=94=AF=E6=8C=81=E5=AF=BC=E5=87=BA?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ppocr/modeling/necks/rnn.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py
index 810c2c8d3b..de87b3d989 100644
--- a/ppocr/modeling/necks/rnn.py
+++ b/ppocr/modeling/necks/rnn.py
@@ -28,8 +28,9 @@ class Im2Seq(nn.Layer):
 
     def forward(self, x):
         B, C, H, W = x.shape
-        x = x.reshape((B, -1, W))
-        x = x.transpose((0, 2, 1))  # (NTC)(batch, width, channels)
+        assert H == 1
+        x = x.squeeze(axis=2)
+        x = x.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
         return x
 
 
@@ -76,7 +77,8 @@ class SequenceEncoder(nn.Layer):
                 'fc': EncoderWithFC,
                 'rnn': EncoderWithRNN
             }
-            assert encoder_type in support_encoder_dict, '{} must in {}'.format(encoder_type, support_encoder_dict.keys())
+            assert encoder_type in support_encoder_dict, '{} must in {}'.format(
+                encoder_type, support_encoder_dict.keys())
 
             self.encoder = support_encoder_dict[encoder_type](
                 self.encoder_reshape.out_channels, hidden_size)

From 33d9688014833434618cfbdd0a14a633e46e83c7 Mon Sep 17 00:00:00 2001
From: WenmuZhou <zjwenmu@gmail.com>
Date: Tue, 10 Nov 2020 17:25:44 +0800
Subject: [PATCH 4/5] =?UTF-8?q?=E6=9B=B4=E6=96=B0NumpyArrayInitializer?=
 =?UTF-8?q?=E4=B8=BAAssign?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ppocr/modeling/transform/tps.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/ppocr/modeling/transform/tps.py b/ppocr/modeling/transform/tps.py
index f5b4f60b8a..50c1740ee4 100644
--- a/ppocr/modeling/transform/tps.py
+++ b/ppocr/modeling/transform/tps.py
@@ -102,13 +102,11 @@ class LocalizationNetwork(nn.Layer):
         name = "loc_fc2"
         param_attr = ParamAttr(
             learning_rate=loc_lr,
-            initializer=paddle.fluid.initializer.NumpyArrayInitializer(
-                np.zeros([fc_dim, F * 2])),
+            initializer=nn.initializer.Assign(np.zeros([fc_dim, F * 2])),
             name=name + "_w")
         bias_attr = ParamAttr(
             learning_rate=loc_lr,
-            initializer=paddle.fluid.initializer.NumpyArrayInitializer(
-                initial_bias),
+            initializer=nn.initializer.Assign(initial_bias),
             name=name + "_b")
         self.fc2 = nn.Linear(
             fc_dim,

From 367c49dffd15c92a5afc59ea5e1af901ecc84fc8 Mon Sep 17 00:00:00 2001
From: WenmuZhou <zjwenmu@gmail.com>
Date: Tue, 10 Nov 2020 18:16:07 +0800
Subject: [PATCH 5/5] =?UTF-8?q?=E5=88=A0=E9=99=A4=20db=20torch=E5=90=8E?=
 =?UTF-8?q?=E5=A4=84=E7=90=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ppocr/postprocess/db_postprocess_torch.py | 136 ----------------------
 1 file changed, 136 deletions(-)
 delete mode 100644 ppocr/postprocess/db_postprocess_torch.py

diff --git a/ppocr/postprocess/db_postprocess_torch.py b/ppocr/postprocess/db_postprocess_torch.py
deleted file mode 100644
index d1466327f1..0000000000
--- a/ppocr/postprocess/db_postprocess_torch.py
+++ /dev/null
@@ -1,136 +0,0 @@
-import cv2
-import paddle
-import numpy as np
-import pyclipper
-from shapely.geometry import Polygon
-
-
-class DBPostProcess():
-    def __init__(self,
-                 thresh=0.3,
-                 box_thresh=0.7,
-                 max_candidates=1000,
-                 unclip_ratio=1.5):
-        self.min_size = 3
-        self.thresh = thresh
-        self.box_thresh = box_thresh
-        self.max_candidates = max_candidates
-        self.unclip_ratio = unclip_ratio
-
-    def __call__(self, pred, shape_list, is_output_polygon=False):
-        '''
-        batch: (image, polygons, ignore_tags
-        h_w_list: 包含[h,w]的数组
-        pred:
-            binary: text region segmentation map, with shape (N, 1,H, W)
-        '''
-        if isinstance(pred, paddle.Tensor):
-            pred = pred.numpy()
-        pred = pred[:, 0, :, :]
-        segmentation = self.binarize(pred)
-        batch_out = []
-        for batch_index in range(pred.shape[0]):
-            height, width = shape_list[batch_index]
-            boxes, scores = self.post_p(
-                pred[batch_index],
-                segmentation[batch_index],
-                width,
-                height,
-                is_output_polygon=is_output_polygon)
-            batch_out.append({"points": boxes})
-        return batch_out
-
-    def binarize(self, pred):
-        return pred > self.thresh
-
-    def post_p(self,
-               pred,
-               bitmap,
-               dest_width,
-               dest_height,
-               is_output_polygon=True):
-        '''
-        _bitmap: single map with shape (H, W),
-            whose values are binarized as {0, 1}
-        '''
-        height, width = pred.shape
-        boxes = []
-        new_scores = []
-        contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
-                                       cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
-        for contour in contours[:self.max_candidates]:
-            epsilon = 0.005 * cv2.arcLength(contour, True)
-            approx = cv2.approxPolyDP(contour, epsilon, True)
-            points = approx.reshape((-1, 2))
-            if points.shape[0] < 4:
-                continue
-            score = self.box_score_fast(pred, points.reshape(-1, 2))
-            if self.box_thresh > score:
-                continue
-
-            if points.shape[0] > 2:
-                box = self.unclip(points, unclip_ratio=self.unclip_ratio)
-                if len(box) > 1 or len(box) == 0:
-                    continue
-            else:
-                continue
-            four_point_box, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
-            if sside < self.min_size + 2:
-                continue
-
-            if not is_output_polygon:
-                box = np.array(four_point_box)
-            else:
-                box = box.reshape(-1, 2)
-            box[:, 0] = np.clip(
-                np.round(box[:, 0] / width * dest_width), 0, dest_width)
-            box[:, 1] = np.clip(
-                np.round(box[:, 1] / height * dest_height), 0, dest_height)
-            boxes.append(box)
-            new_scores.append(score)
-        return boxes, new_scores
-
-    def unclip(self, box, unclip_ratio=1.5):
-        poly = Polygon(box)
-        distance = poly.area * unclip_ratio / poly.length
-        offset = pyclipper.PyclipperOffset()
-        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
-        expanded = np.array(offset.Execute(distance))
-        return expanded
-
-    def get_mini_boxes(self, contour):
-        bounding_box = cv2.minAreaRect(contour)
-        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
-
-        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
-        if points[1][1] > points[0][1]:
-            index_1 = 0
-            index_4 = 1
-        else:
-            index_1 = 1
-            index_4 = 0
-        if points[3][1] > points[2][1]:
-            index_2 = 2
-            index_3 = 3
-        else:
-            index_2 = 3
-            index_3 = 2
-
-        box = [
-            points[index_1], points[index_2], points[index_3], points[index_4]
-        ]
-        return box, min(bounding_box[1])
-
-    def box_score_fast(self, bitmap, _box):
-        h, w = bitmap.shape[:2]
-        box = _box.copy()
-        xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
-        xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
-        ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
-        ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
-
-        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
-        box[:, 0] = box[:, 0] - xmin
-        box[:, 1] = box[:, 1] - ymin
-        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
-        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
\ No newline at end of file