Merge branch 'PaddlePaddle:release/2.3' into release/2.3

commit 349b7d38cd
@@ -141,6 +141,7 @@ Train:
           img_mode: BGR
           channel_first: False
       - DetLabelEncode: # Class handling label
+      - CopyPaste:
       - IaaAugment:
           augmenter_args:
             - { 'type': Fliplr, 'args': { 'p': 0.5 } }
@@ -91,7 +91,7 @@ Optimizer:

 PostProcess:
   name: DistillationDBPostProcess
-  model_name: ["Student", "Student2"]
+  model_name: ["Student"]
   key: head_out
   thresh: 0.3
   box_thresh: 0.6
@@ -8,7 +8,7 @@ Global:
   # evaluation is run every 5000 iterations after the 4000th iteration
   eval_batch_step: [4000, 5000]
   cal_metric_during_train: False
-  pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/
+  pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
   checkpoints:
   save_inference_dir:
   use_visualdl: False

@@ -106,4 +106,4 @@ Eval:
     shuffle: False
     drop_last: False
     batch_size_per_card: 1 # must be 1
     num_workers: 2
@@ -8,7 +8,7 @@ Global:
   # evaluation is run every 5000 iterations after the 4000th iteration
   eval_batch_step: [4000, 5000]
   cal_metric_during_train: False
-  pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/
+  pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
   checkpoints:
   save_inference_dir:
   use_visualdl: False

@@ -105,4 +105,4 @@ Eval:
     shuffle: False
     drop_last: False
     batch_size_per_card: 1 # must be 1
     num_workers: 2
@@ -1,29 +1,28 @@
 Global:
   use_gpu: true
-  epoch_num: 50
+  epoch_num: 400
   log_smooth_window: 20
   print_batch_step: 5
   save_model_dir: ./output/table_mv3/
-  save_epoch_step: 5
+  save_epoch_step: 3
   # evaluation is run every 400 iterations after the 0th iteration
   eval_batch_step: [0, 400]
   cal_metric_during_train: True
   pretrained_model:
   checkpoints:
   save_inference_dir:
   use_visualdl: False
-  infer_img: doc/imgs_words/ch/word_1.jpg
+  infer_img: doc/table/table.jpg
   # for data or label process
   character_dict_path: ppocr/utils/dict/table_structure_dict.txt
   character_type: en
   max_text_length: 100
-  max_elem_length: 500
+  max_elem_length: 800
   max_cell_num: 500
   infer_mode: False
   process_total_num: 0
   process_cut_num: 0
-

 Optimizer:
   name: Adam
   beta1: 0.9
@@ -41,13 +40,15 @@ Architecture:
   Backbone:
     name: MobileNetV3
     scale: 1.0
-    model_name: small
-    disable_se: True
+    model_name: large
   Head:
     name: TableAttentionHead
     hidden_size: 256
     l2_decay: 0.00001
     loc_type: 2
+    max_text_length: 100
+    max_elem_length: 800
+    max_cell_num: 500

 Loss:
   name: TableAttentionLoss
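The Head section now carries the same sequence-length limits as Global, and the head refactor further down reads them from the config instead of hard-coded constants. A small sketch, assuming the file lives at configs/table/table_mv3.yml and PyYAML is available, that checks the two places stay in sync:

```python
import yaml  # PyYAML

# Path assumed from the hunk context (save_model_dir: ./output/table_mv3/).
with open("configs/table/table_mv3.yml") as f:
    cfg = yaml.safe_load(f)

head = cfg["Architecture"]["Head"]
for key in ("max_text_length", "max_elem_length", "max_cell_num"):
    assert head[key] == cfg["Global"][key], f"{key} differs between Global and Head"
```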
@@ -18,7 +18,7 @@ import paddlehub as hub
 from tools.infer.utility import base64_to_cv2
 from tools.infer.predict_det import TextDetector
 from tools.infer.utility import parse_args
-from deploy.hubserving.ocr_system.params import read_params
+from deploy.hubserving.ocr_det.params import read_params


 @moduleinfo(
@@ -50,7 +50,7 @@ List of text recognition algorithms open-sourced in PaddleOCR (dynamic graph):
 - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
 - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
 - [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
-- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
+- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))[13]

 Following the [DTRB][3](https://arxiv.org/abs/1904.01906) text recognition training and evaluation pipeline, the models are trained on the MJSynth and SynthText datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE. Results:

@@ -78,4 +78,3 @@ For training and using the PaddleOCR text detection algorithms, refer to the model training tutorial
 ## 3. Model inference

 Except for the PP-OCR series, the models above only support inference with the Python engine; see [Inference with the Python prediction engine](./inference.md) for details.
-
@@ -112,4 +112,14 @@
   year={2016}
 }

+13.NRTR
+@misc{sheng2019nrtr,
+      title={NRTR: A No-Recurrence Sequence-to-Sequence Model For Scene Text Recognition},
+      author={Fenfen Sheng and Zhineng Chen and Bo Xu},
+      year={2019},
+      eprint={1806.00926},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV}
+}
+
 ```
@@ -11,7 +11,10 @@
 #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #See the License for the specific language governing permissions and
 #limitations under the License.
-
+"""
+This code is refered from:
+https://github.com/songdejia/EAST/blob/master/data_utils.py
+"""
 import math
 import cv2
 import numpy as np
@@ -24,10 +27,10 @@ __all__ = ['EASTProcessTrain']

 class EASTProcessTrain(object):
     def __init__(self,
-                 image_shape = [512, 512],
-                 background_ratio = 0.125,
-                 min_crop_side_ratio = 0.1,
-                 min_text_size = 10,
+                 image_shape=[512, 512],
+                 background_ratio=0.125,
+                 min_crop_side_ratio=0.1,
+                 min_text_size=10,
                  **kwargs):
         self.input_size = image_shape[1]
         self.random_scale = np.array([0.5, 1, 2.0, 3.0])
@@ -282,12 +285,7 @@ class EASTProcessTrain(object):
                         1.0 / max(min(poly_h, poly_w), 1.0)
         return score_map, geo_map, training_mask

-    def crop_area(self,
-                  im,
-                  polys,
-                  tags,
-                  crop_background=False,
-                  max_tries=50):
+    def crop_area(self, im, polys, tags, crop_background=False, max_tries=50):
         """
         make random crop from the input image
         :param im:
@@ -11,6 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+This code is refer from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/iaa_augment.py
+"""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -1,5 +1,20 @@
-# -*- coding:utf-8 -*-
-
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py
+"""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -1,5 +1,20 @@
-# -*- coding:utf-8 -*-
-
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_shrink_map.py
+"""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -1,5 +1,20 @@
-# -*- coding:utf-8 -*-
-
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py
+"""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -11,7 +11,10 @@
 #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #See the License for the specific language governing permissions and
 #limitations under the License.
-
+"""
+This part code is refered from:
+https://github.com/songdejia/EAST/blob/master/data_utils.py
+"""
 import math
 import cv2
 import numpy as np
@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+"""
+This code is refer from:
+https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py
+"""
 import numpy as np
 from .warp_mls import WarpMLS

@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+"""
+This code is refer from:
+https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/warp_mls.py
+"""
 import numpy as np


@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+"""
+This code is refer from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/basic_loss.py
+"""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -11,6 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+This code is refer from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/DB_loss.py
+"""

 from __future__ import absolute_import
 from __future__ import division
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# This code is refer from: https://github.com/PaddlePaddle/PaddleClas/blob/develop/ppcls/arch/backbone/legendary_models/pp_lcnet.py
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -75,7 +75,7 @@ class AttentionHead(nn.Layer):
                         probs_step, axis=1)], axis=1)
                 next_input = probs_step.argmax(axis=1)
                 targets = next_input
-
+            probs = paddle.nn.functional.softmax(probs, axis=2)
         return probs


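The added line normalizes the accumulated attention logits at inference time so the returned scores are probabilities, matching the training branch. A minimal sketch, with illustrative shapes, of what softmax over the class axis does to the stacked per-step logits:

```python
import paddle

# [batch, decode steps, charset size] -- illustrative sizes
logits = paddle.randn([2, 25, 38])
probs = paddle.nn.functional.softmax(logits, axis=2)
print(probs.sum(axis=2))  # every decode step now sums to 1.0
```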
@@ -23,32 +23,40 @@ import numpy as np


 class TableAttentionHead(nn.Layer):
-    def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
+    def __init__(self,
+                 in_channels,
+                 hidden_size,
+                 loc_type,
+                 in_max_len=488,
+                 max_text_length=100,
+                 max_elem_length=800,
+                 max_cell_num=500,
+                 **kwargs):
         super(TableAttentionHead, self).__init__()
         self.input_size = in_channels[-1]
         self.hidden_size = hidden_size
         self.elem_num = 30
-        self.max_text_length = 100
-        self.max_elem_length = 500
-        self.max_cell_num = 500
+        self.max_text_length = max_text_length
+        self.max_elem_length = max_elem_length
+        self.max_cell_num = max_cell_num

         self.structure_attention_cell = AttentionGRUCell(
             self.input_size, hidden_size, self.elem_num, use_gru=False)
         self.structure_generator = nn.Linear(hidden_size, self.elem_num)
         self.loc_type = loc_type
         self.in_max_len = in_max_len

         if self.loc_type == 1:
             self.loc_generator = nn.Linear(hidden_size, 4)
         else:
             if self.in_max_len == 640:
-                self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
+                self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1)
             elif self.in_max_len == 800:
-                self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
+                self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1)
             else:
-                self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
+                self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1)
             self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)

     def _char_to_onehot(self, input_char, onehot_dim):
         input_ont_hot = F.one_hot(input_char, onehot_dim)
         return input_ont_hot
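With this refactor the limits arrive through the constructor, so the values added to the YAML Head section earlier actually reach the head. A hedged sketch, assuming the usual ppocr import path and an illustrative in_channels list:

```python
from ppocr.modeling.heads.table_att_head import TableAttentionHead  # path assumed

head_cfg = {                      # mirrors the Head block in the YAML above
    "hidden_size": 256,
    "loc_type": 2,
    "max_text_length": 100,
    "max_elem_length": 800,
    "max_cell_num": 500,
}
# in_channels is illustrative; the backbone supplies the real channel list.
head = TableAttentionHead(in_channels=[96], **head_cfg)
print(head.max_elem_length)       # -> 800, no longer the hard-coded 500
```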
@@ -60,16 +68,16 @@ class TableAttentionHead(nn.Layer):
         if len(fea.shape) == 3:
             pass
         else:
             last_shape = int(np.prod(fea.shape[2:]))  # gry added
             fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
         fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
         batch_size = fea.shape[0]

         hidden = paddle.zeros((batch_size, self.hidden_size))
         output_hiddens = []
         if self.training and targets is not None:
             structure = targets[0]
-            for i in range(self.max_elem_length+1):
+            for i in range(self.max_elem_length + 1):
                 elem_onehots = self._char_to_onehot(
                     structure[:, i], onehot_dim=self.elem_num)
                 (outputs, hidden), alpha = self.structure_attention_cell(

@@ -96,7 +104,7 @@ class TableAttentionHead(nn.Layer):
             alpha = None
             max_elem_length = paddle.to_tensor(self.max_elem_length)
             i = 0
-            while i < max_elem_length+1:
+            while i < max_elem_length + 1:
                 elem_onehots = self._char_to_onehot(
                     temp_elem, onehot_dim=self.elem_num)
                 (outputs, hidden), alpha = self.structure_attention_cell(

@@ -105,7 +113,7 @@ class TableAttentionHead(nn.Layer):
                 structure_probs_step = self.structure_generator(outputs)
                 temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
                 i += 1

         output = paddle.concat(output_hiddens, axis=1)
         structure_probs = self.structure_generator(output)
         structure_probs = F.softmax(structure_probs)

@@ -119,9 +127,9 @@ class TableAttentionHead(nn.Layer):
             loc_concat = paddle.concat([output, loc_fea], axis=2)
             loc_preds = self.loc_generator(loc_concat)
             loc_preds = F.sigmoid(loc_preds)
-        return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
+        return {'structure_probs': structure_probs, 'loc_preds': loc_preds}


 class AttentionGRUCell(nn.Layer):
     def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
         super(AttentionGRUCell, self).__init__()
@@ -11,6 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+This code is refer from:
+https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/transformation.py
+"""

 from __future__ import absolute_import
 from __future__ import division
@@ -231,7 +235,8 @@ class GridGenerator(nn.Layer):
         """ Return inv_delta_C which is needed to calculate T """
         F = self.F
         hat_eye = paddle.eye(F, dtype='float64')  # F x F
-        hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
+        hat_C = paddle.norm(
+            C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
         hat_C = (hat_C**2) * paddle.log(hat_C)
         delta_C = paddle.concat(  # F+3 x F+3
             [
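For context, hat_C is the thin-plate-spline kernel U(r) = r^2 log r evaluated on pairwise distances between the fiducial points; adding the identity puts ones on the diagonal so the log stays finite. A small numpy sketch of the same computation, with an assumed number of points:

```python
import numpy as np

F_points = 20                                              # number of fiducial points, assumed
C = np.random.rand(F_points, 2)                            # control point coordinates
r = np.linalg.norm(C[None, :, :] - C[:, None, :], axis=2)  # F x F pairwise distances
r = r + np.eye(F_points)                                   # ones on the diagonal, so log() stays finite
hat_C = (r ** 2) * np.log(r)                               # TPS radial basis U(r) = r^2 * log(r)
```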
@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+"""
+This code is refered from:
+https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
+"""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -190,7 +193,8 @@ class DBPostProcess(object):


 class DistillationDBPostProcess(object):
-    def __init__(self, model_name=["student"],
+    def __init__(self,
+                 model_name=["student"],
                  key=None,
                  thresh=0.3,
                  box_thresh=0.6,

@@ -201,12 +205,13 @@ class DistillationDBPostProcess(object):
                  **kwargs):
         self.model_name = model_name
         self.key = key
-        self.post_process = DBPostProcess(thresh=thresh,
-                                          box_thresh=box_thresh,
-                                          max_candidates=max_candidates,
-                                          unclip_ratio=unclip_ratio,
-                                          use_dilation=use_dilation,
-                                          score_mode=score_mode)
+        self.post_process = DBPostProcess(
+            thresh=thresh,
+            box_thresh=box_thresh,
+            max_candidates=max_candidates,
+            unclip_ratio=unclip_ratio,
+            use_dilation=use_dilation,
+            score_mode=score_mode)

     def __call__(self, predicts, shape_list):
         results = {}
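For reference, a sketch of how the reformatted constructor is driven by the PostProcess block changed earlier in this merge (import path assumed; the elided parameters such as max_candidates are assumed to keep their defaults):

```python
from ppocr.postprocess.db_postprocess import DistillationDBPostProcess  # path assumed

post_process = DistillationDBPostProcess(
    model_name=["Student"],   # matches the PostProcess config change above
    key="head_out",
    thresh=0.3,
    box_thresh=0.6)
# internally this builds one DBPostProcess and applies it to the output of each
# sub-model listed in model_name when called with (predicts, shape_list)
```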
@@ -1,5 +1,6 @@
 """
 Locality aware nms.
+This code is refered from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py
 """

 import numpy as np
@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+"""
+This code is refer from:
+https://github.com/WenmuZhou/PytorchOCR/blob/master/torchocr/utils/logging.py
+"""
 import os
 import sys
 import logging
@@ -187,7 +187,7 @@ def create_predictor(args, mode, logger):
                 "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2]
             }
             max_input_shape = {
-                "x": [1, 3, 2000, 2000],
+                "x": [1, 3, 1280, 1280],
                 "conv2d_92.tmp_0": [1, 120, 400, 400],
                 "conv2d_91.tmp_0": [1, 24, 200, 200],
                 "conv2d_59.tmp_0": [1, 96, 400, 400],

@@ -237,16 +237,16 @@ def create_predictor(args, mode, logger):
                 opt_input_shape.update(opt_pact_shape)
             elif mode == "rec":
                 min_input_shape = {"x": [1, 3, 32, 10]}
-                max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]}
+                max_input_shape = {"x": [args.rec_batch_num, 3, 32, 1024]}
                 opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]}
             elif mode == "cls":
                 min_input_shape = {"x": [1, 3, 48, 10]}
-                max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]}
+                max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}
                 opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
             else:
                 min_input_shape = {"x": [1, 3, 10, 10]}
-                max_input_shape = {"x": [1, 3, 1000, 1000]}
-                opt_input_shape = {"x": [1, 3, 500, 500]}
+                max_input_shape = {"x": [1, 3, 512, 512]}
+                opt_input_shape = {"x": [1, 3, 256, 256]}
             config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                               opt_input_shape)

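The shrunken max shapes bound the memory TensorRT pre-allocates for its dynamic-shape engine. A hedged sketch of the surrounding Paddle Inference setup for the rec branch (model paths illustrative):

```python
from paddle.inference import Config, PrecisionType, create_predictor

# model paths illustrative
config = Config("inference/rec/inference.pdmodel", "inference/rec/inference.pdiparams")
config.enable_use_gpu(500, 0)
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    precision_mode=PrecisionType.Float32,
    max_batch_size=10,
    min_subgraph_size=15)
# rec branch after this change: input width may vary between 10 and 1024 pixels
config.set_trt_dynamic_shape_info(
    {"x": [1, 3, 32, 10]},      # min
    {"x": [10, 3, 32, 1024]},   # max (was 2000 before)
    {"x": [10, 3, 32, 320]})    # opt
predictor = create_predictor(config)
```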