add SLANet
parent
342522ab52
commit
ddaa2c2552
|
@ -32,3 +32,31 @@ paddleocr.egg-info/
|
|||
/deploy/android_demo/app/cache/
|
||||
test_tipc/web/models/
|
||||
test_tipc/web/node_modules/
|
||||
en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdiparams
|
||||
en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdiparams.info
|
||||
en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdmodel
|
||||
en_ppocr_mobile_v2.0_table_structure_infer/inference.pdiparams
|
||||
en_ppocr_mobile_v2.0_table_structure_infer/inference.pdiparams.info
|
||||
en_ppocr_mobile_v2.0_table_structure_infer/inference.pdmodel
|
||||
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdiparams
|
||||
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdiparams.info
|
||||
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdmodel
|
||||
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdiparams
|
||||
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdiparams.info
|
||||
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdmodel
|
||||
.gitignore
|
||||
.gitignore
|
||||
ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdiparams
|
||||
ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdiparams.info
|
||||
ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdmodel
|
||||
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/infer_cfg.yml
|
||||
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdiparams
|
||||
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdiparams.info
|
||||
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdmodel
|
||||
.gitignore
|
||||
ppstructure/layout/table/inference.pdiparams
|
||||
ppstructure/layout/table/inference.pdiparams.info
|
||||
ppstructure/layout/table/inference.pdmodel
|
||||
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape.tar
|
||||
._en_ppocr_mobile_v2.0_table_structure_infer
|
||||
en_ppocr_mobile_v2.0_table_structure_infer.tar
|
||||
|
|
|
@ -0,0 +1,141 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 400
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 20
|
||||
save_model_dir: ./output/SLANet
|
||||
save_epoch_step: 400
|
||||
# evaluation is run every 1000 iterations after the 0th iteration
|
||||
eval_batch_step: [0, 1000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints: /ssd1/zhoujun20/table/ch/PaddleOCR/output/en/table_lcnet_1_0_csp_pan_headsv3_smooth_l1_pretrain_ssld_weight81_sync_bn/best_accuracy.pdparams
|
||||
save_inference_dir: ./output/SLANet/infer
|
||||
use_visualdl: False
|
||||
infer_img: doc/table/table.jpg
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
|
||||
character_type: en
|
||||
max_text_length: &max_text_length 500
|
||||
box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
|
||||
infer_mode: False
|
||||
use_sync_bn: True
|
||||
save_res_path: 'output/infer'
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
clip_norm: 5.0
|
||||
lr:
|
||||
# name: Piecewise
|
||||
learning_rate: 0.001
|
||||
# decay_epochs : [10, 20]
|
||||
# values : [0.002, 0.0002, 0.0001]
|
||||
# warmup_epoch: 0
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0.00000
|
||||
|
||||
Architecture:
|
||||
model_type: table
|
||||
algorithm: SLANet
|
||||
Backbone:
|
||||
name: PPLCNet
|
||||
scale: 1.0
|
||||
pretrained: true
|
||||
use_ssld: true
|
||||
Neck:
|
||||
name: CSPPAN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: SLAHead
|
||||
hidden_size: 256
|
||||
max_text_length: *max_text_length
|
||||
loc_reg_num: &loc_reg_num 4
|
||||
|
||||
Loss:
|
||||
name: SLANetLoss
|
||||
structure_weight: 1.0
|
||||
loc_weight: 2.0
|
||||
loc_loss: smooth_l1
|
||||
|
||||
PostProcess:
|
||||
name: TableLabelDecode
|
||||
|
||||
Metric:
|
||||
name: TableMetric
|
||||
main_indicator: acc
|
||||
compute_bbox_metric: False
|
||||
loc_reg_num: *loc_reg_num
|
||||
box_format: *box_format
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: PubTabDataSet
|
||||
data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/train/
|
||||
label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/PubTabNet_2.0.0_train.jsonl]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- TableLabelEncode:
|
||||
learn_empty_box: False
|
||||
merge_no_span_structure: False
|
||||
replace_empty_cell_token: False
|
||||
loc_reg_num: *loc_reg_num
|
||||
max_text_length: *max_text_length
|
||||
- TableBoxEncode:
|
||||
box_format: *box_format
|
||||
- ResizeTableImage:
|
||||
max_len: 488
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- PaddingTableImage:
|
||||
size: [488, 488]
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 48
|
||||
drop_last: True
|
||||
num_workers: 1
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: PubTabDataSet
|
||||
data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/
|
||||
label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/PubTabNet_2.0.0_val.jsonl]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- TableLabelEncode:
|
||||
learn_empty_box: False
|
||||
merge_no_span_structure: False
|
||||
replace_empty_cell_token: False
|
||||
loc_reg_num: *loc_reg_num
|
||||
max_text_length: *max_text_length
|
||||
- TableBoxEncode:
|
||||
box_format: *box_format
|
||||
- ResizeTableImage:
|
||||
max_len: 488
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- PaddingTableImage:
|
||||
size: [488, 488]
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 48
|
||||
num_workers: 1
|
|
@ -15,9 +15,8 @@ Global:
|
|||
save_res_path: ./output/table_master
|
||||
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
|
||||
infer_mode: false
|
||||
max_text_length: 500
|
||||
process_total_num: 0
|
||||
process_cut_num: 0
|
||||
max_text_length: &max_text_length 500
|
||||
box_format: &box_format 'xywh' # 'xywh', 'xyxy', 'xyxyxyxy'
|
||||
|
||||
|
||||
Optimizer:
|
||||
|
@ -52,7 +51,8 @@ Architecture:
|
|||
headers: 8
|
||||
dropout: 0
|
||||
d_ff: 2024
|
||||
max_text_length: 500
|
||||
max_text_length: *max_text_length
|
||||
loc_reg_num: &loc_reg_num 4
|
||||
|
||||
Loss:
|
||||
name: TableMasterLoss
|
||||
|
@ -66,6 +66,7 @@ Metric:
|
|||
name: TableMetric
|
||||
main_indicator: acc
|
||||
compute_bbox_metric: False
|
||||
box_format: *box_format
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
|
@ -80,13 +81,15 @@ Train:
|
|||
learn_empty_box: False
|
||||
merge_no_span_structure: True
|
||||
replace_empty_cell_token: True
|
||||
loc_reg_num: *loc_reg_num
|
||||
max_text_length: *max_text_length
|
||||
- ResizeTableImage:
|
||||
max_len: 480
|
||||
resize_bboxes: True
|
||||
- PaddingTableImage:
|
||||
size: [480, 480]
|
||||
- TableBoxEncode:
|
||||
use_xywh: True
|
||||
box_format: *box_format
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
|
@ -114,13 +117,15 @@ Eval:
|
|||
learn_empty_box: False
|
||||
merge_no_span_structure: True
|
||||
replace_empty_cell_token: True
|
||||
loc_reg_num: *loc_reg_num
|
||||
max_text_length: *max_text_length
|
||||
- ResizeTableImage:
|
||||
max_len: 480
|
||||
resize_bboxes: True
|
||||
- PaddingTableImage:
|
||||
size: [480, 480]
|
||||
- TableBoxEncode:
|
||||
use_xywh: True
|
||||
box_format: *box_format
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
|
|
|
@ -17,10 +17,9 @@ Global:
|
|||
# for data or label process
|
||||
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
|
||||
character_type: en
|
||||
max_text_length: 800
|
||||
max_text_length: &max_text_length 800
|
||||
box_format: &box_format 'xyxy' # 'xywh', 'xyxy', 'xyxyxyxy'
|
||||
infer_mode: False
|
||||
process_total_num: 0
|
||||
process_cut_num: 0
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
|
@ -44,7 +43,8 @@ Architecture:
|
|||
name: TableAttentionHead
|
||||
hidden_size: 256
|
||||
loc_type: 2
|
||||
max_text_length: 800
|
||||
max_text_length: *max_text_length
|
||||
loc_reg_num: &loc_reg_num 4
|
||||
|
||||
Loss:
|
||||
name: TableAttentionLoss
|
||||
|
@ -72,6 +72,8 @@ Train:
|
|||
learn_empty_box: False
|
||||
merge_no_span_structure: False
|
||||
replace_empty_cell_token: False
|
||||
loc_reg_num: *loc_reg_num
|
||||
max_text_length: *max_text_length
|
||||
- TableBoxEncode:
|
||||
- ResizeTableImage:
|
||||
max_len: 488
|
||||
|
@ -104,6 +106,8 @@ Eval:
|
|||
learn_empty_box: False
|
||||
merge_no_span_structure: False
|
||||
replace_empty_cell_token: False
|
||||
loc_reg_num: *loc_reg_num
|
||||
max_text_length: *max_text_length
|
||||
- TableBoxEncode:
|
||||
- ResizeTableImage:
|
||||
max_len: 488
|
||||
|
|
|
@ -118,7 +118,7 @@ class OCRSystem(hub.Module):
|
|||
all_results.append([])
|
||||
continue
|
||||
starttime = time.time()
|
||||
dt_boxes, rec_res = self.text_sys(img)
|
||||
dt_boxes, rec_res, _ = self.text_sys(img)
|
||||
elapse = time.time() - starttime
|
||||
logger.info("Predict time: {}".format(elapse))
|
||||
|
||||
|
|
|
@ -571,7 +571,7 @@ class TableLabelEncode(AttnLabelEncode):
|
|||
replace_empty_cell_token=False,
|
||||
merge_no_span_structure=False,
|
||||
learn_empty_box=False,
|
||||
point_num=2,
|
||||
loc_reg_num=4,
|
||||
**kwargs):
|
||||
self.max_text_len = max_text_length
|
||||
self.lower = False
|
||||
|
@ -593,7 +593,7 @@ class TableLabelEncode(AttnLabelEncode):
|
|||
self.idx2char = {v: k for k, v in self.dict.items()}
|
||||
|
||||
self.character = dict_character
|
||||
self.point_num = point_num
|
||||
self.loc_reg_num = loc_reg_num
|
||||
self.pad_idx = self.dict[self.beg_str]
|
||||
self.start_idx = self.dict[self.beg_str]
|
||||
self.end_idx = self.dict[self.end_str]
|
||||
|
@ -649,7 +649,7 @@ class TableLabelEncode(AttnLabelEncode):
|
|||
|
||||
# encode box
|
||||
bboxes = np.zeros(
|
||||
(self._max_text_len, self.point_num * 2), dtype=np.float32)
|
||||
(self._max_text_len, self.loc_reg_num), dtype=np.float32)
|
||||
bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
|
||||
|
||||
bbox_idx = 0
|
||||
|
@ -714,11 +714,11 @@ class TableMasterLabelEncode(TableLabelEncode):
|
|||
replace_empty_cell_token=False,
|
||||
merge_no_span_structure=False,
|
||||
learn_empty_box=False,
|
||||
point_num=2,
|
||||
loc_reg_num=4,
|
||||
**kwargs):
|
||||
super(TableMasterLabelEncode, self).__init__(
|
||||
max_text_length, character_dict_path, replace_empty_cell_token,
|
||||
merge_no_span_structure, learn_empty_box, point_num, **kwargs)
|
||||
merge_no_span_structure, learn_empty_box, loc_reg_num, **kwargs)
|
||||
self.pad_idx = self.dict[self.pad_str]
|
||||
self.unknown_idx = self.dict[self.unknown_str]
|
||||
|
||||
|
@ -739,13 +739,14 @@ class TableMasterLabelEncode(TableLabelEncode):
|
|||
|
||||
|
||||
class TableBoxEncode(object):
|
||||
def __init__(self, use_xywh=False, **kwargs):
|
||||
self.use_xywh = use_xywh
|
||||
def __init__(self, box_format='xyxy', **kwargs):
|
||||
assert box_format in ['xywh', 'xyxy', 'xyxyxyxy']
|
||||
self.box_format = box_format
|
||||
|
||||
def __call__(self, data):
|
||||
img_height, img_width = data['image'].shape[:2]
|
||||
bboxes = data['bboxes']
|
||||
if self.use_xywh and bboxes.shape[1] == 4:
|
||||
if self.box_format == 'xywh' and bboxes.shape[1] == 4:
|
||||
bboxes = self.xyxy2xywh(bboxes)
|
||||
bboxes[:, 0::2] /= img_width
|
||||
bboxes[:, 1::2] /= img_height
|
||||
|
@ -1217,6 +1218,7 @@ class ABINetLabelEncode(BaseRecLabelEncode):
|
|||
dict_character = ['</s>'] + dict_character
|
||||
return dict_character
|
||||
|
||||
|
||||
class SPINLabelEncode(AttnLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
@ -1229,6 +1231,7 @@ class SPINLabelEncode(AttnLabelEncode):
|
|||
super(SPINLabelEncode, self).__init__(
|
||||
max_text_length, character_dict_path, use_space_char)
|
||||
self.lower = lower
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
|
@ -1248,4 +1251,4 @@ class SPINLabelEncode(AttnLabelEncode):
|
|||
|
||||
padded_text[:len(target)] = target
|
||||
data['label'] = np.array(padded_text)
|
||||
return data
|
||||
return data
|
||||
|
|
|
@ -206,7 +206,7 @@ class ResizeTableImage(object):
|
|||
data['bboxes'] = data['bboxes'] * ratio
|
||||
data['image'] = resize_img
|
||||
data['src_img'] = img
|
||||
data['shape'] = np.array([resize_h, resize_w, ratio, ratio])
|
||||
data['shape'] = np.array([height, width, ratio, ratio])
|
||||
data['max_len'] = self.max_len
|
||||
return data
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ from .basic_loss import DistanceLoss
|
|||
from .combined_loss import CombinedLoss
|
||||
|
||||
# table loss
|
||||
from .table_att_loss import TableAttentionLoss
|
||||
from .table_att_loss import TableAttentionLoss, SLANetLoss
|
||||
from .table_master_loss import TableMasterLoss
|
||||
# vqa token loss
|
||||
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
|
||||
|
@ -63,7 +63,7 @@ def build_loss(config):
|
|||
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
|
||||
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
|
||||
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
|
||||
'TableMasterLoss', 'SPINAttentionLoss'
|
||||
'TableMasterLoss', 'SPINAttentionLoss', 'SLANetLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -22,65 +22,11 @@ from paddle.nn import functional as F
|
|||
|
||||
|
||||
class TableAttentionLoss(nn.Layer):
|
||||
def __init__(self,
|
||||
structure_weight,
|
||||
loc_weight,
|
||||
use_giou=False,
|
||||
giou_weight=1.0,
|
||||
**kwargs):
|
||||
def __init__(self, structure_weight, loc_weight, **kwargs):
|
||||
super(TableAttentionLoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
|
||||
self.structure_weight = structure_weight
|
||||
self.loc_weight = loc_weight
|
||||
self.use_giou = use_giou
|
||||
self.giou_weight = giou_weight
|
||||
|
||||
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
|
||||
'''
|
||||
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
|
||||
:param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
|
||||
:return: loss
|
||||
'''
|
||||
ix1 = paddle.maximum(preds[:, 0], bbox[:, 0])
|
||||
iy1 = paddle.maximum(preds[:, 1], bbox[:, 1])
|
||||
ix2 = paddle.minimum(preds[:, 2], bbox[:, 2])
|
||||
iy2 = paddle.minimum(preds[:, 3], bbox[:, 3])
|
||||
|
||||
iw = paddle.clip(ix2 - ix1 + 1e-3, 0., 1e10)
|
||||
ih = paddle.clip(iy2 - iy1 + 1e-3, 0., 1e10)
|
||||
|
||||
# overlap
|
||||
inters = iw * ih
|
||||
|
||||
# union
|
||||
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (
|
||||
preds[:, 3] - preds[:, 1] + 1e-3) + (bbox[:, 2] - bbox[:, 0] + 1e-3
|
||||
) * (bbox[:, 3] - bbox[:, 1] +
|
||||
1e-3) - inters + eps
|
||||
|
||||
# ious
|
||||
ious = inters / uni
|
||||
|
||||
ex1 = paddle.minimum(preds[:, 0], bbox[:, 0])
|
||||
ey1 = paddle.minimum(preds[:, 1], bbox[:, 1])
|
||||
ex2 = paddle.maximum(preds[:, 2], bbox[:, 2])
|
||||
ey2 = paddle.maximum(preds[:, 3], bbox[:, 3])
|
||||
ew = paddle.clip(ex2 - ex1 + 1e-3, 0., 1e10)
|
||||
eh = paddle.clip(ey2 - ey1 + 1e-3, 0., 1e10)
|
||||
|
||||
# enclose erea
|
||||
enclose = ew * eh + eps
|
||||
giou = ious - (enclose - uni) / enclose
|
||||
|
||||
loss = 1 - giou
|
||||
|
||||
if reduction == 'mean':
|
||||
loss = paddle.mean(loss)
|
||||
elif reduction == 'sum':
|
||||
loss = paddle.sum(loss)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return loss
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
structure_probs = predicts['structure_probs']
|
||||
|
@ -100,20 +46,48 @@ class TableAttentionLoss(nn.Layer):
|
|||
loc_targets_mask = loc_targets_mask[:, 1:, :]
|
||||
loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
|
||||
loc_targets) * self.loc_weight
|
||||
if self.use_giou:
|
||||
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask,
|
||||
loc_targets) * self.giou_weight
|
||||
total_loss = structure_loss + loc_loss + loc_loss_giou
|
||||
return {
|
||||
'loss': total_loss,
|
||||
"structure_loss": structure_loss,
|
||||
"loc_loss": loc_loss,
|
||||
"loc_loss_giou": loc_loss_giou
|
||||
}
|
||||
else:
|
||||
total_loss = structure_loss + loc_loss
|
||||
return {
|
||||
'loss': total_loss,
|
||||
"structure_loss": structure_loss,
|
||||
"loc_loss": loc_loss
|
||||
}
|
||||
|
||||
total_loss = structure_loss + loc_loss
|
||||
return {
|
||||
'loss': total_loss,
|
||||
"structure_loss": structure_loss,
|
||||
"loc_loss": loc_loss
|
||||
}
|
||||
|
||||
|
||||
class SLANetLoss(nn.Layer):
|
||||
def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
|
||||
super(SLANetLoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
|
||||
self.structure_weight = structure_weight
|
||||
self.loc_weight = loc_weight
|
||||
self.loc_loss = loc_loss
|
||||
self.eps = 1e-12
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
structure_probs = predicts['structure_probs']
|
||||
structure_targets = batch[1].astype("int64")
|
||||
structure_targets = structure_targets[:, 1:]
|
||||
|
||||
structure_loss = self.loss_func(structure_probs, structure_targets)
|
||||
|
||||
structure_loss = paddle.mean(structure_loss) * self.structure_weight
|
||||
|
||||
loc_preds = predicts['loc_preds']
|
||||
loc_targets = batch[2].astype("float32")
|
||||
loc_targets_mask = batch[3].astype("float32")
|
||||
loc_targets = loc_targets[:, 1:, :]
|
||||
loc_targets_mask = loc_targets_mask[:, 1:, :]
|
||||
|
||||
loc_loss = F.smooth_l1_loss(
|
||||
loc_preds * loc_targets_mask,
|
||||
loc_targets * loc_targets_mask,
|
||||
reduction='sum') * self.loc_weight
|
||||
|
||||
loc_loss = loc_loss / (loc_targets_mask.sum() + self.eps)
|
||||
total_loss = structure_loss + loc_loss
|
||||
return {
|
||||
'loss': total_loss,
|
||||
"structure_loss": structure_loss,
|
||||
"loc_loss": loc_loss
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ class TableMetric(object):
|
|||
def __init__(self,
|
||||
main_indicator='acc',
|
||||
compute_bbox_metric=False,
|
||||
point_num=2,
|
||||
box_format='xyxy',
|
||||
**kwargs):
|
||||
"""
|
||||
|
||||
|
@ -70,7 +70,7 @@ class TableMetric(object):
|
|||
self.structure_metric = TableStructureMetric()
|
||||
self.bbox_metric = DetMetric() if compute_bbox_metric else None
|
||||
self.main_indicator = main_indicator
|
||||
self.point_num = point_num
|
||||
self.box_format = box_format
|
||||
self.reset()
|
||||
|
||||
def __call__(self, pred_label, batch=None, *args, **kwargs):
|
||||
|
@ -129,10 +129,14 @@ class TableMetric(object):
|
|||
self.bbox_metric.reset()
|
||||
|
||||
def format_box(self, box):
|
||||
if self.point_num == 2:
|
||||
if self.box_format == 'xyxy':
|
||||
x1, y1, x2, y2 = box
|
||||
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
|
||||
elif self.point_num == 4:
|
||||
elif self.box_format == 'xywh':
|
||||
x, y, w, h = box
|
||||
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
|
||||
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
|
||||
elif self.box_format == 'xyxyxyxy':
|
||||
x1, y1, x2, y2, x3, y3, x4, y4 = box
|
||||
box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
|
||||
return box
|
||||
|
|
|
@ -21,7 +21,10 @@ def build_backbone(config, model_type):
|
|||
from .det_resnet import ResNet
|
||||
from .det_resnet_vd import ResNet_vd
|
||||
from .det_resnet_vd_sast import ResNet_SAST
|
||||
support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"]
|
||||
from .det_pp_lcnet import PPLCNet
|
||||
support_dict = [
|
||||
"MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet"
|
||||
]
|
||||
if model_type == "table":
|
||||
from .table_master_resnet import TableResNetExtra
|
||||
support_dict.append('TableResNetExtra')
|
||||
|
|
|
@ -0,0 +1,271 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddle import ParamAttr
|
||||
from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear
|
||||
from paddle.regularizer import L2Decay
|
||||
from paddle.nn.initializer import KaimingNormal
|
||||
from paddle.utils.download import get_path_from_url
|
||||
|
||||
MODEL_URLS = {
|
||||
"PPLCNet_x0.25":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_25_pretrained.pdparams",
|
||||
"PPLCNet_x0.35":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_35_pretrained.pdparams",
|
||||
"PPLCNet_x0.5":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_5_pretrained.pdparams",
|
||||
"PPLCNet_x0.75":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_75_pretrained.pdparams",
|
||||
"PPLCNet_x1.0":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_0_pretrained.pdparams",
|
||||
"PPLCNet_x1.5":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_5_pretrained.pdparams",
|
||||
"PPLCNet_x2.0":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams",
|
||||
"PPLCNet_x2.5":
|
||||
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams"
|
||||
}
|
||||
|
||||
MODEL_STAGES_PATTERN = {
|
||||
"PPLCNet": ["blocks2", "blocks3", "blocks4", "blocks5", "blocks6"]
|
||||
}
|
||||
|
||||
__all__ = list(MODEL_URLS.keys())
|
||||
|
||||
# Each element(list) represents a depthwise block, which is composed of k, in_c, out_c, s, use_se.
|
||||
# k: kernel_size
|
||||
# in_c: input channel number in depthwise block
|
||||
# out_c: output channel number in depthwise block
|
||||
# s: stride in depthwise block
|
||||
# use_se: whether to use SE block
|
||||
|
||||
NET_CONFIG = {
|
||||
"blocks2":
|
||||
# k, in_c, out_c, s, use_se
|
||||
[[3, 16, 32, 1, False]],
|
||||
"blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
|
||||
"blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
|
||||
"blocks5":
|
||||
[[3, 128, 256, 2, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False],
|
||||
[5, 256, 256, 1, False], [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
|
||||
"blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
|
||||
}
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None):
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
filter_size,
|
||||
num_filters,
|
||||
stride,
|
||||
num_groups=1):
|
||||
super().__init__()
|
||||
|
||||
self.conv = Conv2D(
|
||||
in_channels=num_channels,
|
||||
out_channels=num_filters,
|
||||
kernel_size=filter_size,
|
||||
stride=stride,
|
||||
padding=(filter_size - 1) // 2,
|
||||
groups=num_groups,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()),
|
||||
bias_attr=False)
|
||||
|
||||
self.bn = BatchNorm(
|
||||
num_filters,
|
||||
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
self.hardswish = nn.Hardswish()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.hardswish(x)
|
||||
return x
|
||||
|
||||
|
||||
class DepthwiseSeparable(nn.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters,
|
||||
stride,
|
||||
dw_size=3,
|
||||
use_se=False):
|
||||
super().__init__()
|
||||
self.use_se = use_se
|
||||
self.dw_conv = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=num_channels,
|
||||
filter_size=dw_size,
|
||||
stride=stride,
|
||||
num_groups=num_channels)
|
||||
if use_se:
|
||||
self.se = SEModule(num_channels)
|
||||
self.pw_conv = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
filter_size=1,
|
||||
num_filters=num_filters,
|
||||
stride=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.dw_conv(x)
|
||||
if self.use_se:
|
||||
x = self.se(x)
|
||||
x = self.pw_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class SEModule(nn.Layer):
|
||||
def __init__(self, channel, reduction=4):
|
||||
super().__init__()
|
||||
self.avg_pool = AdaptiveAvgPool2D(1)
|
||||
self.conv1 = Conv2D(
|
||||
in_channels=channel,
|
||||
out_channels=channel // reduction,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv2 = Conv2D(
|
||||
in_channels=channel // reduction,
|
||||
out_channels=channel,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.hardsigmoid = nn.Hardsigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.avg_pool(x)
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv2(x)
|
||||
x = self.hardsigmoid(x)
|
||||
x = paddle.multiply(x=identity, y=x)
|
||||
return x
|
||||
|
||||
|
||||
class PPLCNet(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
scale=1.0,
|
||||
pretrained=False,
|
||||
use_ssld=False):
|
||||
super().__init__()
|
||||
self.out_channels = [
|
||||
int(NET_CONFIG["blocks3"][-1][2] * scale),
|
||||
int(NET_CONFIG["blocks4"][-1][2] * scale),
|
||||
int(NET_CONFIG["blocks5"][-1][2] * scale),
|
||||
int(NET_CONFIG["blocks6"][-1][2] * scale)
|
||||
]
|
||||
self.scale = scale
|
||||
|
||||
self.conv1 = ConvBNLayer(
|
||||
num_channels=in_channels,
|
||||
filter_size=3,
|
||||
num_filters=make_divisible(16 * scale),
|
||||
stride=2)
|
||||
|
||||
self.blocks2 = nn.Sequential(* [
|
||||
DepthwiseSeparable(
|
||||
num_channels=make_divisible(in_c * scale),
|
||||
num_filters=make_divisible(out_c * scale),
|
||||
dw_size=k,
|
||||
stride=s,
|
||||
use_se=se)
|
||||
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks2"])
|
||||
])
|
||||
|
||||
self.blocks3 = nn.Sequential(* [
|
||||
DepthwiseSeparable(
|
||||
num_channels=make_divisible(in_c * scale),
|
||||
num_filters=make_divisible(out_c * scale),
|
||||
dw_size=k,
|
||||
stride=s,
|
||||
use_se=se)
|
||||
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks3"])
|
||||
])
|
||||
|
||||
self.blocks4 = nn.Sequential(* [
|
||||
DepthwiseSeparable(
|
||||
num_channels=make_divisible(in_c * scale),
|
||||
num_filters=make_divisible(out_c * scale),
|
||||
dw_size=k,
|
||||
stride=s,
|
||||
use_se=se)
|
||||
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks4"])
|
||||
])
|
||||
|
||||
self.blocks5 = nn.Sequential(* [
|
||||
DepthwiseSeparable(
|
||||
num_channels=make_divisible(in_c * scale),
|
||||
num_filters=make_divisible(out_c * scale),
|
||||
dw_size=k,
|
||||
stride=s,
|
||||
use_se=se)
|
||||
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks5"])
|
||||
])
|
||||
|
||||
self.blocks6 = nn.Sequential(* [
|
||||
DepthwiseSeparable(
|
||||
num_channels=make_divisible(in_c * scale),
|
||||
num_filters=make_divisible(out_c * scale),
|
||||
dw_size=k,
|
||||
stride=s,
|
||||
use_se=se)
|
||||
for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks6"])
|
||||
])
|
||||
|
||||
if pretrained:
|
||||
self._load_pretrained(
|
||||
MODEL_URLS['PPLCNet_x{}'.format(scale)], use_ssld=use_ssld)
|
||||
|
||||
def forward(self, x):
|
||||
outs = []
|
||||
x = self.conv1(x)
|
||||
x = self.blocks2(x)
|
||||
x = self.blocks3(x)
|
||||
outs.append(x)
|
||||
x = self.blocks4(x)
|
||||
outs.append(x)
|
||||
x = self.blocks5(x)
|
||||
outs.append(x)
|
||||
x = self.blocks6(x)
|
||||
outs.append(x)
|
||||
return outs
|
||||
|
||||
def _load_pretrained(self, pretrained_url, use_ssld=False):
|
||||
if use_ssld:
|
||||
pretrained_url = pretrained_url.replace("_pretrained",
|
||||
"_ssld_pretrained")
|
||||
print(pretrained_url)
|
||||
local_weight_path = get_path_from_url(
|
||||
pretrained_url, os.path.expanduser("~/.paddleclas/weights"))
|
||||
param_state_dict = paddle.load(local_weight_path)
|
||||
self.set_dict(param_state_dict)
|
||||
return
|
|
@ -42,14 +42,15 @@ def build_head(config):
|
|||
#kie head
|
||||
from .kie_sdmgr_head import SDMGRHead
|
||||
|
||||
from .table_att_head import TableAttentionHead
|
||||
from .table_att_head import TableAttentionHead, SLAHead
|
||||
from .table_master_head import TableMasterHead
|
||||
|
||||
support_dict = [
|
||||
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
|
||||
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
|
||||
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
|
||||
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead'
|
||||
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
|
||||
'SLAHead'
|
||||
]
|
||||
|
||||
#table head
|
||||
|
|
|
@ -18,12 +18,26 @@ from __future__ import print_function
|
|||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from .rec_att_head import AttentionGRUCell
|
||||
|
||||
|
||||
def get_para_bias_attr(l2_decay, k):
|
||||
if l2_decay > 0:
|
||||
regularizer = paddle.regularizer.L2Decay(l2_decay)
|
||||
stdv = 1.0 / math.sqrt(k * 1.0)
|
||||
initializer = nn.initializer.Uniform(-stdv, stdv)
|
||||
else:
|
||||
regularizer = None
|
||||
initializer = None
|
||||
weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
|
||||
bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
|
||||
return [weight_attr, bias_attr]
|
||||
|
||||
|
||||
class TableAttentionHead(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
|
@ -32,7 +46,7 @@ class TableAttentionHead(nn.Layer):
|
|||
in_max_len=488,
|
||||
max_text_length=800,
|
||||
out_channels=30,
|
||||
point_num=2,
|
||||
loc_reg_num=4,
|
||||
**kwargs):
|
||||
super(TableAttentionHead, self).__init__()
|
||||
self.input_size = in_channels[-1]
|
||||
|
@ -56,7 +70,7 @@ class TableAttentionHead(nn.Layer):
|
|||
else:
|
||||
self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
|
||||
self.loc_generator = nn.Linear(self.input_size + hidden_size,
|
||||
point_num * 2)
|
||||
loc_reg_num)
|
||||
|
||||
def _char_to_onehot(self, input_char, onehot_dim):
|
||||
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||
|
@ -129,3 +143,121 @@ class TableAttentionHead(nn.Layer):
|
|||
loc_preds = self.loc_generator(loc_concat)
|
||||
loc_preds = F.sigmoid(loc_preds)
|
||||
return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
|
||||
|
||||
|
||||
class SLAHead(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
hidden_size,
|
||||
out_channels=30,
|
||||
max_text_length=500,
|
||||
loc_reg_num=4,
|
||||
fc_decay=0.0,
|
||||
**kwargs):
|
||||
"""
|
||||
@param in_channels: input shape
|
||||
@param hidden_size: hidden_size for RNN and Embedding
|
||||
@param out_channels: num_classes to rec
|
||||
@param max_text_length: max text pred
|
||||
"""
|
||||
super().__init__()
|
||||
in_channels = in_channels[-1]
|
||||
self.hidden_size = hidden_size
|
||||
self.max_text_length = max_text_length
|
||||
self.emb = self._char_to_onehot
|
||||
self.num_embeddings = out_channels
|
||||
|
||||
# structure
|
||||
self.structure_attention_cell = AttentionGRUCell(
|
||||
in_channels, hidden_size, self.num_embeddings)
|
||||
weight_attr, bias_attr = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=hidden_size)
|
||||
weight_attr1_1, bias_attr1_1 = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=hidden_size)
|
||||
weight_attr1_2, bias_attr1_2 = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=hidden_size)
|
||||
self.structure_generator = nn.Sequential(
|
||||
nn.Linear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
weight_attr=weight_attr1_2,
|
||||
bias_attr=bias_attr1_2),
|
||||
nn.Linear(
|
||||
hidden_size,
|
||||
out_channels,
|
||||
weight_attr=weight_attr,
|
||||
bias_attr=bias_attr))
|
||||
# loc
|
||||
weight_attr1, bias_attr1 = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=self.hidden_size)
|
||||
weight_attr2, bias_attr2 = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=self.hidden_size)
|
||||
self.loc_generator = nn.Sequential(
|
||||
nn.Linear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
weight_attr=weight_attr1,
|
||||
bias_attr=bias_attr1),
|
||||
nn.Linear(
|
||||
self.hidden_size,
|
||||
loc_reg_num,
|
||||
weight_attr=weight_attr2,
|
||||
bias_attr=bias_attr2),
|
||||
nn.Sigmoid())
|
||||
|
||||
def forward(self, inputs, targets=None):
|
||||
fea = inputs[-1]
|
||||
batch_size = fea.shape[0]
|
||||
# reshape
|
||||
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], -1])
|
||||
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
|
||||
|
||||
hidden = paddle.zeros((batch_size, self.hidden_size))
|
||||
structure_preds = []
|
||||
loc_preds = []
|
||||
if self.training and targets is not None:
|
||||
structure = targets[0]
|
||||
for i in range(self.max_text_length + 1):
|
||||
hidden, structure_step, loc_step = self._decode(structure[:, i],
|
||||
fea, hidden)
|
||||
structure_preds.append(structure_step)
|
||||
loc_preds.append(loc_step)
|
||||
else:
|
||||
pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||
max_text_length = paddle.to_tensor(self.max_text_length)
|
||||
# for export
|
||||
loc_step, structure_step = None, None
|
||||
for i in range(max_text_length + 1):
|
||||
hidden, structure_step, loc_step = self._decode(pre_chars, fea,
|
||||
hidden)
|
||||
pre_chars = structure_step.argmax(axis=1, dtype="int32")
|
||||
structure_preds.append(structure_step)
|
||||
loc_preds.append(loc_step)
|
||||
structure_preds = paddle.stack(structure_preds, axis=1)
|
||||
loc_preds = paddle.stack(loc_preds, axis=1)
|
||||
if not self.training:
|
||||
structure_preds = F.softmax(structure_preds)
|
||||
return {'structure_probs': structure_preds, 'loc_preds': loc_preds}
|
||||
|
||||
def _decode(self, pre_chars, features, hidden):
|
||||
"""
|
||||
Predict table label and coordinates for each step
|
||||
@param pre_chars: Table label in previous step
|
||||
@param features:
|
||||
@param hidden: hidden status in previous step
|
||||
@return:
|
||||
"""
|
||||
emb_feature = self.emb(pre_chars)
|
||||
# output shape is b * self.hidden_size
|
||||
(output, hidden), alpha = self.structure_attention_cell(
|
||||
hidden, features, emb_feature)
|
||||
|
||||
# structure
|
||||
structure_step = self.structure_generator(output)
|
||||
# loc
|
||||
loc_step = self.loc_generator(output)
|
||||
return hidden, structure_step, loc_step
|
||||
|
||||
def _char_to_onehot(self, input_char):
|
||||
input_ont_hot = F.one_hot(input_char, self.num_embeddings)
|
||||
return input_ont_hot
|
||||
|
|
|
@ -37,7 +37,7 @@ class TableMasterHead(nn.Layer):
|
|||
d_ff=2048,
|
||||
dropout=0,
|
||||
max_text_length=500,
|
||||
point_num=2,
|
||||
loc_reg_num=4,
|
||||
**kwargs):
|
||||
super(TableMasterHead, self).__init__()
|
||||
hidden_size = in_channels[-1]
|
||||
|
@ -50,7 +50,7 @@ class TableMasterHead(nn.Layer):
|
|||
self.cls_fc = nn.Linear(hidden_size, out_channels)
|
||||
self.bbox_fc = nn.Sequential(
|
||||
# nn.Linear(hidden_size, hidden_size),
|
||||
nn.Linear(hidden_size, point_num * 2),
|
||||
nn.Linear(hidden_size, loc_reg_num),
|
||||
nn.Sigmoid())
|
||||
self.norm = nn.LayerNorm(hidden_size)
|
||||
self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
|
||||
|
@ -59,7 +59,7 @@ class TableMasterHead(nn.Layer):
|
|||
self.SOS = out_channels - 3
|
||||
self.PAD = out_channels - 1
|
||||
self.out_channels = out_channels
|
||||
self.point_num = point_num
|
||||
self.loc_reg_num = loc_reg_num
|
||||
self.max_text_length = max_text_length
|
||||
|
||||
def make_mask(self, tgt):
|
||||
|
@ -105,7 +105,7 @@ class TableMasterHead(nn.Layer):
|
|||
output = paddle.zeros(
|
||||
[input.shape[0], self.max_text_length + 1, self.out_channels])
|
||||
bbox_output = paddle.zeros(
|
||||
[input.shape[0], self.max_text_length + 1, self.point_num * 2])
|
||||
[input.shape[0], self.max_text_length + 1, self.loc_reg_num])
|
||||
max_text_length = paddle.to_tensor(self.max_text_length)
|
||||
for i in range(max_text_length + 1):
|
||||
target_mask = self.make_mask(input)
|
||||
|
|
|
@ -25,9 +25,10 @@ def build_neck(config):
|
|||
from .fpn import FPN
|
||||
from .fce_fpn import FCEFPN
|
||||
from .pren_fpn import PRENFPN
|
||||
from .csp_pan import CSPPAN
|
||||
support_dict = [
|
||||
'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
|
||||
'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN'
|
||||
'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN'
|
||||
]
|
||||
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,325 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The code is based on:
|
||||
# https://github.com/PaddlePaddle/PaddleDetection/blob/release%2F2.3/ppdet/modeling/necks/csp_pan.py
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
__all__ = ['CSPPAN']
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channel=96,
|
||||
out_channel=96,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
groups=1,
|
||||
act='leaky_relu'):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
initializer = nn.initializer.KaimingUniform()
|
||||
self.act = act
|
||||
assert self.act in ['leaky_relu', "hard_swish"]
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=in_channel,
|
||||
out_channels=out_channel,
|
||||
kernel_size=kernel_size,
|
||||
groups=groups,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
stride=stride,
|
||||
weight_attr=ParamAttr(initializer=initializer),
|
||||
bias_attr=False)
|
||||
self.bn = nn.BatchNorm2D(out_channel)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.bn(self.conv(x))
|
||||
if self.act == "leaky_relu":
|
||||
x = F.leaky_relu(x)
|
||||
elif self.act == "hard_swish":
|
||||
x = F.hardswish(x)
|
||||
return x
|
||||
|
||||
|
||||
class DPModule(nn.Layer):
|
||||
"""
|
||||
Depth-wise and point-wise module.
|
||||
Args:
|
||||
in_channel (int): The input channels of this Module.
|
||||
out_channel (int): The output channels of this Module.
|
||||
kernel_size (int): The conv2d kernel size of this Module.
|
||||
stride (int): The conv2d's stride of this Module.
|
||||
act (str): The activation function of this Module,
|
||||
Now support `leaky_relu` and `hard_swish`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channel=96,
|
||||
out_channel=96,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='leaky_relu'):
|
||||
super(DPModule, self).__init__()
|
||||
initializer = nn.initializer.KaimingUniform()
|
||||
self.act = act
|
||||
self.dwconv = nn.Conv2D(
|
||||
in_channels=in_channel,
|
||||
out_channels=out_channel,
|
||||
kernel_size=kernel_size,
|
||||
groups=out_channel,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
stride=stride,
|
||||
weight_attr=ParamAttr(initializer=initializer),
|
||||
bias_attr=False)
|
||||
self.bn1 = nn.BatchNorm2D(out_channel)
|
||||
self.pwconv = nn.Conv2D(
|
||||
in_channels=out_channel,
|
||||
out_channels=out_channel,
|
||||
kernel_size=1,
|
||||
groups=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(initializer=initializer),
|
||||
bias_attr=False)
|
||||
self.bn2 = nn.BatchNorm2D(out_channel)
|
||||
|
||||
def act_func(self, x):
|
||||
if self.act == "leaky_relu":
|
||||
x = F.leaky_relu(x)
|
||||
elif self.act == "hard_swish":
|
||||
x = F.hardswish(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.act_func(self.bn1(self.dwconv(x)))
|
||||
x = self.act_func(self.bn2(self.pwconv(x)))
|
||||
return x
|
||||
|
||||
|
||||
class DarknetBottleneck(nn.Layer):
|
||||
"""The basic bottleneck block used in Darknet.
|
||||
Each Block consists of two ConvModules and the input is added to the
|
||||
final output. Each ConvModule is composed of Conv, BN, and act.
|
||||
The first convLayer has filter size of 1x1 and the second one has the
|
||||
filter size of 3x3.
|
||||
Args:
|
||||
in_channels (int): The input channels of this Module.
|
||||
out_channels (int): The output channels of this Module.
|
||||
expansion (int): The kernel size of the convolution. Default: 0.5
|
||||
add_identity (bool): Whether to add identity to the out.
|
||||
Default: True
|
||||
use_depthwise (bool): Whether to use depthwise separable convolution.
|
||||
Default: False
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
expansion=0.5,
|
||||
add_identity=True,
|
||||
use_depthwise=False,
|
||||
act="leaky_relu"):
|
||||
super(DarknetBottleneck, self).__init__()
|
||||
hidden_channels = int(out_channels * expansion)
|
||||
conv_func = DPModule if use_depthwise else ConvBNLayer
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channel=in_channels,
|
||||
out_channel=hidden_channels,
|
||||
kernel_size=1,
|
||||
act=act)
|
||||
self.conv2 = conv_func(
|
||||
in_channel=hidden_channels,
|
||||
out_channel=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
act=act)
|
||||
self.add_identity = \
|
||||
add_identity and in_channels == out_channels
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
out = self.conv1(x)
|
||||
out = self.conv2(out)
|
||||
|
||||
if self.add_identity:
|
||||
return out + identity
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
class CSPLayer(nn.Layer):
|
||||
"""Cross Stage Partial Layer.
|
||||
Args:
|
||||
in_channels (int): The input channels of the CSP layer.
|
||||
out_channels (int): The output channels of the CSP layer.
|
||||
expand_ratio (float): Ratio to adjust the number of channels of the
|
||||
hidden layer. Default: 0.5
|
||||
num_blocks (int): Number of blocks. Default: 1
|
||||
add_identity (bool): Whether to add identity in blocks.
|
||||
Default: True
|
||||
use_depthwise (bool): Whether to depthwise separable convolution in
|
||||
blocks. Default: False
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
expand_ratio=0.5,
|
||||
num_blocks=1,
|
||||
add_identity=True,
|
||||
use_depthwise=False,
|
||||
act="leaky_relu"):
|
||||
super().__init__()
|
||||
mid_channels = int(out_channels * expand_ratio)
|
||||
self.main_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
|
||||
self.short_conv = ConvBNLayer(in_channels, mid_channels, 1, act=act)
|
||||
self.final_conv = ConvBNLayer(
|
||||
2 * mid_channels, out_channels, 1, act=act)
|
||||
|
||||
self.blocks = nn.Sequential(* [
|
||||
DarknetBottleneck(
|
||||
mid_channels,
|
||||
mid_channels,
|
||||
kernel_size,
|
||||
1.0,
|
||||
add_identity,
|
||||
use_depthwise,
|
||||
act=act) for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
x_short = self.short_conv(x)
|
||||
|
||||
x_main = self.main_conv(x)
|
||||
x_main = self.blocks(x_main)
|
||||
|
||||
x_final = paddle.concat((x_main, x_short), axis=1)
|
||||
return self.final_conv(x_final)
|
||||
|
||||
|
||||
class Channel_T(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels=[116, 232, 464],
|
||||
out_channels=96,
|
||||
act="leaky_relu"):
|
||||
super(Channel_T, self).__init__()
|
||||
self.convs = nn.LayerList()
|
||||
for i in range(len(in_channels)):
|
||||
self.convs.append(
|
||||
ConvBNLayer(
|
||||
in_channels[i], out_channels, 1, act=act))
|
||||
|
||||
def forward(self, x):
|
||||
outs = [self.convs[i](x[i]) for i in range(len(x))]
|
||||
return outs
|
||||
|
||||
|
||||
class CSPPAN(nn.Layer):
|
||||
"""Path Aggregation Network with CSP module.
|
||||
Args:
|
||||
in_channels (List[int]): Number of input channels per scale.
|
||||
out_channels (int): Number of output channels (used at each scale)
|
||||
kernel_size (int): The conv2d kernel size of this Module.
|
||||
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 1
|
||||
use_depthwise (bool): Whether to depthwise separable convolution in
|
||||
blocks. Default: True
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=5,
|
||||
num_csp_blocks=1,
|
||||
use_depthwise=True,
|
||||
act='hard_swish'):
|
||||
super(CSPPAN, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = [out_channels] * len(in_channels)
|
||||
conv_func = DPModule if use_depthwise else ConvBNLayer
|
||||
|
||||
self.conv_t = Channel_T(in_channels, out_channels, act=act)
|
||||
|
||||
# build top-down blocks
|
||||
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
|
||||
self.top_down_blocks = nn.LayerList()
|
||||
for idx in range(len(in_channels) - 1, 0, -1):
|
||||
self.top_down_blocks.append(
|
||||
CSPLayer(
|
||||
out_channels * 2,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
num_blocks=num_csp_blocks,
|
||||
add_identity=False,
|
||||
use_depthwise=use_depthwise,
|
||||
act=act))
|
||||
|
||||
# build bottom-up blocks
|
||||
self.downsamples = nn.LayerList()
|
||||
self.bottom_up_blocks = nn.LayerList()
|
||||
for idx in range(len(in_channels) - 1):
|
||||
self.downsamples.append(
|
||||
conv_func(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=2,
|
||||
act=act))
|
||||
self.bottom_up_blocks.append(
|
||||
CSPLayer(
|
||||
out_channels * 2,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
num_blocks=num_csp_blocks,
|
||||
add_identity=False,
|
||||
use_depthwise=use_depthwise,
|
||||
act=act))
|
||||
|
||||
def forward(self, inputs):
|
||||
"""
|
||||
Args:
|
||||
inputs (tuple[Tensor]): input features.
|
||||
Returns:
|
||||
tuple[Tensor]: CSPPAN features.
|
||||
"""
|
||||
assert len(inputs) == len(self.in_channels)
|
||||
inputs = self.conv_t(inputs)
|
||||
|
||||
# top-down path
|
||||
inner_outs = [inputs[-1]]
|
||||
for idx in range(len(self.in_channels) - 1, 0, -1):
|
||||
feat_heigh = inner_outs[0]
|
||||
feat_low = inputs[idx - 1]
|
||||
|
||||
upsample_feat = F.upsample(
|
||||
feat_heigh, size=feat_low.shape[2:4], mode="nearest")
|
||||
|
||||
inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
|
||||
paddle.concat([upsample_feat, feat_low], 1))
|
||||
inner_outs.insert(0, inner_out)
|
||||
|
||||
# bottom-up path
|
||||
outs = [inner_outs[0]]
|
||||
for idx in range(len(self.in_channels) - 1):
|
||||
feat_low = outs[-1]
|
||||
feat_height = inner_outs[idx + 1]
|
||||
downsample_feat = self.downsamples[idx](feat_low)
|
||||
out = self.bottom_up_blocks[idx](paddle.concat(
|
||||
[downsample_feat, feat_height], 1))
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
|
@ -23,7 +23,7 @@ class TableLabelDecode(AttnLabelDecode):
|
|||
|
||||
def __init__(self, character_dict_path, **kwargs):
|
||||
super(TableLabelDecode, self).__init__(character_dict_path)
|
||||
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
|
||||
self.td_token = ['<td>', '<td', '<td></td>']
|
||||
|
||||
def __call__(self, preds, batch=None):
|
||||
structure_probs = preds['structure_probs']
|
||||
|
@ -114,10 +114,8 @@ class TableLabelDecode(AttnLabelDecode):
|
|||
|
||||
def _bbox_decode(self, bbox, shape):
|
||||
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
|
||||
src_h = h / ratio_h
|
||||
src_w = w / ratio_w
|
||||
bbox[0::2] *= src_w
|
||||
bbox[1::2] *= src_h
|
||||
bbox[0::2] *= w
|
||||
bbox[1::2] *= h
|
||||
return bbox
|
||||
|
||||
|
||||
|
@ -157,4 +155,7 @@ class TableMasterLabelDecode(TableLabelDecode):
|
|||
bbox[1::2] *= h
|
||||
bbox[0::2] /= ratio_w
|
||||
bbox[1::2] /= ratio_h
|
||||
x, y, w, h = bbox
|
||||
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
|
||||
bbox = np.array([x1, y1, x2, y2])
|
||||
return bbox
|
||||
|
|
|
@ -113,14 +113,10 @@ def draw_re_results(image,
|
|||
return np.array(img_new)
|
||||
|
||||
|
||||
def draw_rectangle(img_path, boxes, use_xywh=False):
|
||||
def draw_rectangle(img_path, boxes):
|
||||
img = cv2.imread(img_path)
|
||||
img_show = img.copy()
|
||||
for box in boxes.astype(int):
|
||||
if use_xywh:
|
||||
x, y, w, h = box
|
||||
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
|
||||
else:
|
||||
x1, y1, x2, y2 = box
|
||||
x1, y1, x2, y2 = box
|
||||
cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
|
||||
return img_show
|
|
@ -0,0 +1,227 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
from scipy.special import softmax
|
||||
|
||||
|
||||
def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
|
||||
"""
|
||||
Args:
|
||||
box_scores (N, 5): boxes in corner-form and probabilities.
|
||||
iou_threshold: intersection over union threshold.
|
||||
top_k: keep top_k results. If k <= 0, keep all the results.
|
||||
candidate_size: only consider the candidates with the highest scores.
|
||||
Returns:
|
||||
picked: a list of indexes of the kept boxes
|
||||
"""
|
||||
scores = box_scores[:, -1]
|
||||
boxes = box_scores[:, :-1]
|
||||
picked = []
|
||||
indexes = np.argsort(scores)
|
||||
indexes = indexes[-candidate_size:]
|
||||
while len(indexes) > 0:
|
||||
current = indexes[-1]
|
||||
picked.append(current)
|
||||
if 0 < top_k == len(picked) or len(indexes) == 1:
|
||||
break
|
||||
current_box = boxes[current, :]
|
||||
indexes = indexes[:-1]
|
||||
rest_boxes = boxes[indexes, :]
|
||||
iou = iou_of(
|
||||
rest_boxes,
|
||||
np.expand_dims(
|
||||
current_box, axis=0), )
|
||||
indexes = indexes[iou <= iou_threshold]
|
||||
|
||||
return box_scores[picked, :]
|
||||
|
||||
|
||||
def iou_of(boxes0, boxes1, eps=1e-5):
|
||||
"""Return intersection-over-union (Jaccard index) of boxes.
|
||||
Args:
|
||||
boxes0 (N, 4): ground truth boxes.
|
||||
boxes1 (N or 1, 4): predicted boxes.
|
||||
eps: a small number to avoid 0 as denominator.
|
||||
Returns:
|
||||
iou (N): IoU values.
|
||||
"""
|
||||
overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
|
||||
overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
|
||||
|
||||
overlap_area = area_of(overlap_left_top, overlap_right_bottom)
|
||||
area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
|
||||
area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
|
||||
return overlap_area / (area0 + area1 - overlap_area + eps)
|
||||
|
||||
|
||||
def area_of(left_top, right_bottom):
|
||||
"""Compute the areas of rectangles given two corners.
|
||||
Args:
|
||||
left_top (N, 2): left top corner.
|
||||
right_bottom (N, 2): right bottom corner.
|
||||
Returns:
|
||||
area (N): return the area.
|
||||
"""
|
||||
hw = np.clip(right_bottom - left_top, 0.0, None)
|
||||
return hw[..., 0] * hw[..., 1]
|
||||
|
||||
|
||||
class PicoDetPostProcess(object):
|
||||
"""
|
||||
Args:
|
||||
input_shape (int): network input image size
|
||||
ori_shape (int): ori image shape of before padding
|
||||
scale_factor (float): scale factor of ori image
|
||||
enable_mkldnn (bool): whether to open MKLDNN
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_shape,
|
||||
ori_shape,
|
||||
scale_factor,
|
||||
strides=[8, 16, 32, 64],
|
||||
score_threshold=0.4,
|
||||
nms_threshold=0.5,
|
||||
nms_top_k=1000,
|
||||
keep_top_k=100):
|
||||
self.ori_shape = ori_shape
|
||||
self.input_shape = input_shape
|
||||
self.scale_factor = scale_factor
|
||||
self.strides = strides
|
||||
self.score_threshold = score_threshold
|
||||
self.nms_threshold = nms_threshold
|
||||
self.nms_top_k = nms_top_k
|
||||
self.keep_top_k = keep_top_k
|
||||
|
||||
def warp_boxes(self, boxes, ori_shape):
|
||||
"""Apply transform to boxes
|
||||
"""
|
||||
width, height = ori_shape[1], ori_shape[0]
|
||||
n = len(boxes)
|
||||
if n:
|
||||
# warp points
|
||||
xy = np.ones((n * 4, 3))
|
||||
xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
|
||||
n * 4, 2) # x1y1, x2y2, x1y2, x2y1
|
||||
# xy = xy @ M.T # transform
|
||||
xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale
|
||||
# create new boxes
|
||||
x = xy[:, [0, 2, 4, 6]]
|
||||
y = xy[:, [1, 3, 5, 7]]
|
||||
xy = np.concatenate(
|
||||
(x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
|
||||
# clip boxes
|
||||
xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
|
||||
xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
|
||||
return xy.astype(np.float32)
|
||||
else:
|
||||
return boxes
|
||||
|
||||
def __call__(self, scores, raw_boxes):
|
||||
batch_size = raw_boxes[0].shape[0]
|
||||
reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
|
||||
out_boxes_num = []
|
||||
out_boxes_list = []
|
||||
for batch_id in range(batch_size):
|
||||
# generate centers
|
||||
decode_boxes = []
|
||||
select_scores = []
|
||||
for stride, box_distribute, score in zip(self.strides, raw_boxes,
|
||||
scores):
|
||||
box_distribute = box_distribute[batch_id]
|
||||
score = score[batch_id]
|
||||
# centers
|
||||
fm_h = self.input_shape[0] / stride
|
||||
fm_w = self.input_shape[1] / stride
|
||||
h_range = np.arange(fm_h)
|
||||
w_range = np.arange(fm_w)
|
||||
ww, hh = np.meshgrid(w_range, h_range)
|
||||
ct_row = (hh.flatten() + 0.5) * stride
|
||||
ct_col = (ww.flatten() + 0.5) * stride
|
||||
center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1)
|
||||
|
||||
# box distribution to distance
|
||||
reg_range = np.arange(reg_max + 1)
|
||||
box_distance = box_distribute.reshape((-1, reg_max + 1))
|
||||
box_distance = softmax(box_distance, axis=1)
|
||||
box_distance = box_distance * np.expand_dims(reg_range, axis=0)
|
||||
box_distance = np.sum(box_distance, axis=1).reshape((-1, 4))
|
||||
box_distance = box_distance * stride
|
||||
|
||||
# top K candidate
|
||||
topk_idx = np.argsort(score.max(axis=1))[::-1]
|
||||
topk_idx = topk_idx[:self.nms_top_k]
|
||||
center = center[topk_idx]
|
||||
score = score[topk_idx]
|
||||
box_distance = box_distance[topk_idx]
|
||||
|
||||
# decode box
|
||||
decode_box = center + [-1, -1, 1, 1] * box_distance
|
||||
|
||||
select_scores.append(score)
|
||||
decode_boxes.append(decode_box)
|
||||
|
||||
# nms
|
||||
bboxes = np.concatenate(decode_boxes, axis=0)
|
||||
confidences = np.concatenate(select_scores, axis=0)
|
||||
picked_box_probs = []
|
||||
picked_labels = []
|
||||
for class_index in range(0, confidences.shape[1]):
|
||||
probs = confidences[:, class_index]
|
||||
mask = probs > self.score_threshold
|
||||
probs = probs[mask]
|
||||
if probs.shape[0] == 0:
|
||||
continue
|
||||
subset_boxes = bboxes[mask, :]
|
||||
box_probs = np.concatenate(
|
||||
[subset_boxes, probs.reshape(-1, 1)], axis=1)
|
||||
box_probs = hard_nms(
|
||||
box_probs,
|
||||
iou_threshold=self.nms_threshold,
|
||||
top_k=self.keep_top_k, )
|
||||
picked_box_probs.append(box_probs)
|
||||
picked_labels.extend([class_index] * box_probs.shape[0])
|
||||
|
||||
if len(picked_box_probs) == 0:
|
||||
out_boxes_list.append(np.empty((0, 4)))
|
||||
out_boxes_num.append(0)
|
||||
|
||||
else:
|
||||
picked_box_probs = np.concatenate(picked_box_probs)
|
||||
|
||||
# resize output boxes
|
||||
picked_box_probs[:, :4] = self.warp_boxes(
|
||||
picked_box_probs[:, :4], self.ori_shape[batch_id])
|
||||
im_scale = np.concatenate([
|
||||
self.scale_factor[batch_id][::-1],
|
||||
self.scale_factor[batch_id][::-1]
|
||||
])
|
||||
picked_box_probs[:, :4] /= im_scale
|
||||
# clas score box
|
||||
out_boxes_list.append(
|
||||
np.concatenate(
|
||||
[
|
||||
np.expand_dims(
|
||||
np.array(picked_labels),
|
||||
axis=-1), np.expand_dims(
|
||||
picked_box_probs[:, 4], axis=-1),
|
||||
picked_box_probs[:, :4]
|
||||
],
|
||||
axis=1))
|
||||
out_boxes_num.append(len(picked_labels))
|
||||
|
||||
out_boxes_list = np.concatenate(out_boxes_list, axis=0)
|
||||
out_boxes_num = np.asarray(out_boxes_num).astype(np.int32)
|
||||
return out_boxes_list, out_boxes_num
|
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.data import create_operators, transform
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.logging import get_logger
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
from ppstructure.utility import parse_args
|
||||
from picodet_postprocess import PicoDetPostProcess
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class LayoutPredictor(object):
|
||||
def __init__(self, args):
|
||||
pre_process_list = [{
|
||||
'Resize': {
|
||||
'size': [800, 608]
|
||||
}
|
||||
}, {
|
||||
'NormalizeImage': {
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'scale': '1./255.',
|
||||
'order': 'hwc'
|
||||
}
|
||||
}, {
|
||||
'ToCHWImage': None
|
||||
}, {
|
||||
'KeepKeys': {
|
||||
'keep_keys': ['image']
|
||||
}
|
||||
}]
|
||||
# postprocess_params = {
|
||||
# 'name': 'LayoutPostProcess',
|
||||
# "character_dict_path": args.layout_dict_path,
|
||||
# }
|
||||
|
||||
self.preprocess_op = create_operators(pre_process_list)
|
||||
# self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||
utility.create_predictor(args, 'layout', logger)
|
||||
|
||||
def __call__(self, img):
|
||||
ori_im = img.copy()
|
||||
data = {'image': img}
|
||||
data = transform(data, self.preprocess_op)
|
||||
img = data[0]
|
||||
|
||||
if img is None:
|
||||
return None, 0
|
||||
|
||||
img = np.expand_dims(img, axis=0)
|
||||
img = img.copy()
|
||||
|
||||
preds, elapse = 0, 1
|
||||
starttime = time.time()
|
||||
|
||||
self.input_tensor.copy_from_cpu(img)
|
||||
self.predictor.run()
|
||||
|
||||
# outputs = []
|
||||
# for output_tensor in self.output_tensors:
|
||||
# output = output_tensor.copy_to_cpu()
|
||||
# outputs.append(output)
|
||||
np_score_list, np_boxes_list = [], []
|
||||
output_names = self.predictor.get_output_names()
|
||||
num_outs = int(len(output_names) / 2)
|
||||
for out_idx in range(num_outs):
|
||||
np_score_list.append(
|
||||
self.predictor.get_output_handle(output_names[out_idx])
|
||||
.copy_to_cpu())
|
||||
np_boxes_list.append(
|
||||
self.predictor.get_output_handle(output_names[
|
||||
out_idx + num_outs]).copy_to_cpu())
|
||||
# result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
|
||||
postprocessor = PicoDetPostProcess(
|
||||
(800, 608), [[800., 608.]],
|
||||
np.array([[1.010101, 0.99346405]]),
|
||||
strides=[8, 16, 32, 64],
|
||||
nms_threshold=0.5)
|
||||
np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list)
|
||||
result = dict(boxes=np_boxes, boxes_num=np_boxes_num)
|
||||
# print(result)
|
||||
im_bboxes_num = result['boxes_num'][0]
|
||||
# print('im_bboxes_num:',im_bboxes_num)
|
||||
|
||||
bboxs = result['boxes'][0:0 + im_bboxes_num, :]
|
||||
threshold = 0.5
|
||||
expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
|
||||
np_boxes = np_boxes[expect_boxes, :]
|
||||
preds = []
|
||||
|
||||
id2label = {1: 'text', 2: 'title', 3: 'list', 4: 'table', 5: 'figure'}
|
||||
for dt in np_boxes:
|
||||
clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
|
||||
label = id2label[clsid + 1]
|
||||
result_di = {'bbox': bbox, 'label': label}
|
||||
preds.append(result_di)
|
||||
# print('result_di',result_di)
|
||||
# print('clsid, bbox, score:',clsid, bbox, score)
|
||||
|
||||
elapse = time.time() - starttime
|
||||
return preds, elapse
|
||||
|
||||
|
||||
def main(args):
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
layout_predictor = LayoutPredictor(args)
|
||||
count = 0
|
||||
total_time = 0
|
||||
|
||||
for image_file in image_file_list:
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
logger.info("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
layout_res, elapse = layout_predictor(img)
|
||||
|
||||
logger.info("result: {}".format(layout_res))
|
||||
|
||||
if count > 0:
|
||||
total_time += elapse
|
||||
count += 1
|
||||
logger.info("Predict time of {}: {}".format(image_file, elapse))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(parse_args())
|
|
@ -18,7 +18,7 @@ import subprocess
|
|||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../')))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||
import cv2
|
||||
|
@ -32,6 +32,7 @@ from attrdict import AttrDict
|
|||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
from ppocr.utils.logging import get_logger
|
||||
from tools.infer.predict_system import TextSystem
|
||||
from ppstructure.layout.predict_layout import LayoutPredictor
|
||||
from ppstructure.table.predict_table import TableSystem, to_excel
|
||||
from ppstructure.utility import parse_args, draw_structure_result
|
||||
from ppstructure.recovery.recovery_to_doc import convert_info_docx
|
||||
|
@ -51,28 +52,14 @@ class StructureSystem(object):
|
|||
"When args.layout is false, args.ocr is automatically set to false"
|
||||
)
|
||||
args.drop_score = 0
|
||||
# init layout and ocr model
|
||||
# init model
|
||||
self.layout_predictor = None
|
||||
self.text_system = None
|
||||
self.table_system = None
|
||||
if args.layout:
|
||||
import layoutparser as lp
|
||||
config_path = None
|
||||
model_path = None
|
||||
if os.path.isdir(args.layout_path_model):
|
||||
model_path = args.layout_path_model
|
||||
else:
|
||||
config_path = args.layout_path_model
|
||||
self.table_layout = lp.PaddleDetectionLayoutModel(
|
||||
config_path=config_path,
|
||||
model_path=model_path,
|
||||
label_map=args.layout_label_map,
|
||||
threshold=0.5,
|
||||
enable_mkldnn=args.enable_mkldnn,
|
||||
enforce_cpu=not args.use_gpu,
|
||||
thread_num=args.cpu_threads)
|
||||
self.layout_predictor = LayoutPredictor(args)
|
||||
if args.ocr:
|
||||
self.text_system = TextSystem(args)
|
||||
else:
|
||||
self.table_layout = None
|
||||
if args.table:
|
||||
if self.text_system is not None:
|
||||
self.table_system = TableSystem(
|
||||
|
@ -80,38 +67,59 @@ class StructureSystem(object):
|
|||
self.text_system.text_recognizer)
|
||||
else:
|
||||
self.table_system = TableSystem(args)
|
||||
else:
|
||||
self.table_system = None
|
||||
|
||||
elif self.mode == 'vqa':
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, img, return_ocr_result_in_table=False):
|
||||
time_dict = {
|
||||
'layout': 0,
|
||||
'table': 0,
|
||||
'table_match': 0,
|
||||
'det': 0,
|
||||
'rec': 0,
|
||||
'vqa': 0,
|
||||
'all': 0
|
||||
}
|
||||
start = time.time()
|
||||
if self.mode == 'structure':
|
||||
ori_im = img.copy()
|
||||
if self.table_layout is not None:
|
||||
layout_res = self.table_layout.detect(img[..., ::-1])
|
||||
if self.layout_predictor is not None:
|
||||
layout_res, elapse = self.layout_predictor(img)
|
||||
time_dict['layout'] += elapse
|
||||
else:
|
||||
h, w = ori_im.shape[:2]
|
||||
layout_res = [AttrDict(coordinates=[0, 0, w, h], type='Table')]
|
||||
layout_res = [dict(bbox=None, label='table')]
|
||||
res_list = []
|
||||
for region in layout_res:
|
||||
res = ''
|
||||
x1, y1, x2, y2 = region.coordinates
|
||||
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
||||
roi_img = ori_im[y1:y2, x1:x2, :]
|
||||
if region.type == 'Table':
|
||||
if region['bbox'] is not None:
|
||||
x1, y1, x2, y2 = region['bbox']
|
||||
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
||||
roi_img = ori_im[y1:y2, x1:x2, :]
|
||||
else:
|
||||
x1, y1, x2, y2 = 0, 0, w, h
|
||||
roi_img = ori_im
|
||||
if region['label'] == 'table':
|
||||
if self.table_system is not None:
|
||||
res = self.table_system(roi_img,
|
||||
return_ocr_result_in_table)
|
||||
res, table_time_dict = self.table_system(
|
||||
roi_img, return_ocr_result_in_table)
|
||||
time_dict['table'] += table_time_dict['table']
|
||||
time_dict['table_match'] += table_time_dict['match']
|
||||
time_dict['det'] += table_time_dict['det']
|
||||
time_dict['rec'] += table_time_dict['rec']
|
||||
else:
|
||||
if self.text_system is not None:
|
||||
if args.recovery:
|
||||
wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
|
||||
wht_im[y1:y2, x1:x2, :] = roi_img
|
||||
filter_boxes, filter_rec_res = self.text_system(wht_im)
|
||||
filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
|
||||
wht_im)
|
||||
else:
|
||||
filter_boxes, filter_rec_res = self.text_system(roi_img)
|
||||
filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
|
||||
roi_img)
|
||||
time_dict['det'] += ocr_time_dict['det']
|
||||
time_dict['rec'] += ocr_time_dict['rec']
|
||||
# remove style char
|
||||
style_token = [
|
||||
'<strike>', '<strike>', '<sup>', '</sub>', '<b>',
|
||||
|
@ -133,15 +141,17 @@ class StructureSystem(object):
|
|||
'text_region': box.tolist()
|
||||
})
|
||||
res_list.append({
|
||||
'type': region.type,
|
||||
'type': region['label'].lower(),
|
||||
'bbox': [x1, y1, x2, y2],
|
||||
'img': roi_img,
|
||||
'res': res
|
||||
})
|
||||
return res_list
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
return res_list, time_dict
|
||||
elif self.mode == 'vqa':
|
||||
raise NotImplementedError
|
||||
return None
|
||||
return None, None
|
||||
|
||||
|
||||
def save_structure_res(res, save_folder, img_name):
|
||||
|
@ -156,12 +166,12 @@ def save_structure_res(res, save_folder, img_name):
|
|||
roi_img = region.pop('img')
|
||||
f.write('{}\n'.format(json.dumps(region)))
|
||||
|
||||
if region['type'] == 'Table' and len(region[
|
||||
if region['type'] == 'table' and len(region[
|
||||
'res']) > 0 and 'html' in region['res']:
|
||||
excel_path = os.path.join(excel_save_folder,
|
||||
'{}.xlsx'.format(region['bbox']))
|
||||
to_excel(region['res']['html'], excel_path)
|
||||
elif region['type'] == 'Figure':
|
||||
elif region['type'] == 'figure':
|
||||
img_path = os.path.join(excel_save_folder,
|
||||
'{}.jpg'.format(region['bbox']))
|
||||
cv2.imwrite(img_path, roi_img)
|
||||
|
@ -188,7 +198,7 @@ def main(args):
|
|||
logger.error("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
starttime = time.time()
|
||||
res = structure_sys(img)
|
||||
res, time_dict = structure_sys(img)
|
||||
|
||||
if structure_sys.mode == 'structure':
|
||||
save_structure_res(res, save_folder, img_name)
|
||||
|
@ -201,7 +211,7 @@ def main(args):
|
|||
cv2.imwrite(img_save_path, draw_img)
|
||||
logger.info('result save to {}'.format(img_save_path))
|
||||
if args.recovery:
|
||||
convert_info_docx(img, res, save_folder, img_name)
|
||||
convert_info_docx(img, res, save_folder, img_name)
|
||||
elapse = time.time() - starttime
|
||||
logger.info("Predict time : {:.3f}s".format(elapse))
|
||||
|
||||
|
|
|
@ -13,12 +13,14 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import cv2
|
||||
import json
|
||||
import pickle
|
||||
import paddle
|
||||
from tqdm import tqdm
|
||||
from ppstructure.table.table_metric import TEDS
|
||||
from ppstructure.table.predict_table import TableSystem
|
||||
|
@ -33,40 +35,74 @@ def parse_args():
|
|||
parser.add_argument("--gt_path", type=str)
|
||||
return parser.parse_args()
|
||||
|
||||
def main(gt_path, img_root, args):
|
||||
teds = TEDS(n_jobs=16)
|
||||
|
||||
def load_txt(txt_path):
|
||||
pred_html_dict = {}
|
||||
if not os.path.exists(txt_path):
|
||||
return pred_html_dict
|
||||
with open(txt_path, encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
line = line.strip().split('\t')
|
||||
img_name, pred_html = line
|
||||
pred_html_dict[img_name] = pred_html
|
||||
return pred_html_dict
|
||||
|
||||
|
||||
def load_result(path):
|
||||
data = {}
|
||||
if os.path.exists(path):
|
||||
data = pickle.load(open(path, 'rb'))
|
||||
return data
|
||||
|
||||
|
||||
def save_result(path, data):
|
||||
old_data = load_result(path)
|
||||
old_data.update(data)
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(old_data, f)
|
||||
|
||||
|
||||
def main(gt_path, img_root, args):
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
# init TableSystem
|
||||
text_sys = TableSystem(args)
|
||||
jsons_gt = json.load(open(gt_path)) # gt
|
||||
# load gt and preds html result
|
||||
gt_html_dict = load_txt(gt_path)
|
||||
|
||||
ocr_result = load_result(os.path.join(args.output, 'ocr.pickle'))
|
||||
structure_result = load_result(
|
||||
os.path.join(args.output, 'structure.pickle'))
|
||||
|
||||
pred_htmls = []
|
||||
gt_htmls = []
|
||||
for img_name in tqdm(jsons_gt):
|
||||
# read image
|
||||
img = cv2.imread(os.path.join(img_root,img_name))
|
||||
pred_html = text_sys(img)
|
||||
for img_name, gt_html in tqdm(gt_html_dict.items()):
|
||||
img = cv2.imread(os.path.join(img_root, img_name))
|
||||
# run ocr and save result
|
||||
if img_name not in ocr_result:
|
||||
dt_boxes, rec_res, _, _ = text_sys._ocr(img)
|
||||
ocr_result[img_name] = [dt_boxes, rec_res]
|
||||
save_result(os.path.join(args.output, 'ocr.pickle'), ocr_result)
|
||||
# run structure and save result
|
||||
if img_name not in structure_result:
|
||||
structure_res, _ = text_sys._structure(img)
|
||||
structure_result[img_name] = structure_res
|
||||
save_result(
|
||||
os.path.join(args.output, 'structure.pickle'), structure_result)
|
||||
dt_boxes, rec_res = ocr_result[img_name]
|
||||
structure_res = structure_result[img_name]
|
||||
# match ocr and structure
|
||||
pred_html = text_sys.match(structure_res, dt_boxes, rec_res)
|
||||
|
||||
pred_htmls.append(pred_html)
|
||||
|
||||
gt_structures, gt_bboxes, gt_contents = jsons_gt[img_name]
|
||||
gt_html, gt = get_gt_html(gt_structures, gt_contents)
|
||||
gt_htmls.append(gt_html)
|
||||
|
||||
# compute teds
|
||||
teds = TEDS(n_jobs=16)
|
||||
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
|
||||
logger.info('teds:', sum(scores) / len(scores))
|
||||
|
||||
|
||||
def get_gt_html(gt_structures, gt_contents):
|
||||
end_html = []
|
||||
td_index = 0
|
||||
for tag in gt_structures:
|
||||
if '</td>' in tag:
|
||||
if gt_contents[td_index] != []:
|
||||
end_html.extend(gt_contents[td_index])
|
||||
end_html.append(tag)
|
||||
td_index += 1
|
||||
else:
|
||||
end_html.append(tag)
|
||||
return ''.join(end_html), end_html
|
||||
logger.info('teds: {}'.format(sum(scores) / len(scores)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
main(args.gt_path,args.image_dir, args)
|
||||
main(args.gt_path, args.image_dir, args)
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
import json
|
||||
from ppstructure.table.table_master_match import deal_eb_token, deal_bb
|
||||
|
||||
|
||||
def distance(box_1, box_2):
|
||||
x1, y1, x2, y2 = box_1
|
||||
x3, y3, x4, y4 = box_2
|
||||
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
|
||||
dis_2 = abs(x3 - x1) + abs(y3 - y1)
|
||||
dis_3 = abs(x4- x2) + abs(y4 - y2)
|
||||
return dis + min(dis_2, dis_3)
|
||||
x1, y1, x2, y2 = box_1
|
||||
x3, y3, x4, y4 = box_2
|
||||
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
|
||||
dis_2 = abs(x3 - x1) + abs(y3 - y1)
|
||||
dis_3 = abs(x4 - x2) + abs(y4 - y2)
|
||||
return dis + min(dis_2, dis_3)
|
||||
|
||||
|
||||
def compute_iou(rec1, rec2):
|
||||
"""
|
||||
|
@ -18,23 +22,22 @@ def compute_iou(rec1, rec2):
|
|||
# computing area of each rectangles
|
||||
S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
|
||||
S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
|
||||
|
||||
|
||||
# computing the sum_area
|
||||
sum_area = S_rec1 + S_rec2
|
||||
|
||||
|
||||
# find the each edge of intersect rectangle
|
||||
left_line = max(rec1[1], rec2[1])
|
||||
right_line = min(rec1[3], rec2[3])
|
||||
top_line = max(rec1[0], rec2[0])
|
||||
bottom_line = min(rec1[2], rec2[2])
|
||||
|
||||
|
||||
# judge if there is an intersect
|
||||
if left_line >= right_line or top_line >= bottom_line:
|
||||
return 0.0
|
||||
else:
|
||||
intersect = (right_line - left_line) * (bottom_line - top_line)
|
||||
return (intersect / (sum_area - intersect))*1.0
|
||||
|
||||
return (intersect / (sum_area - intersect)) * 1.0
|
||||
|
||||
|
||||
def matcher_merge(ocr_bboxes, pred_bboxes):
|
||||
|
@ -45,15 +48,18 @@ def matcher_merge(ocr_bboxes, pred_bboxes):
|
|||
distances = []
|
||||
for j, pred_box in enumerate(pred_bboxes):
|
||||
# compute l1 distence and IOU between two boxes
|
||||
distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
|
||||
distances.append((distance(gt_box, pred_box),
|
||||
1. - compute_iou(gt_box, pred_box)))
|
||||
sorted_distances = distances.copy()
|
||||
# select nearest cell
|
||||
sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0]))
|
||||
if distances.index(sorted_distances[0]) not in matched.keys():
|
||||
sorted_distances = sorted(
|
||||
sorted_distances, key=lambda item: (item[1], item[0]))
|
||||
if distances.index(sorted_distances[0]) not in matched.keys():
|
||||
matched[distances.index(sorted_distances[0])] = [i]
|
||||
else:
|
||||
matched[distances.index(sorted_distances[0])].append(i)
|
||||
return matched#, sum(ious) / len(ious)
|
||||
return matched #, sum(ious) / len(ious)
|
||||
|
||||
|
||||
def complex_num(pred_bboxes):
|
||||
complex_nums = []
|
||||
|
@ -67,6 +73,7 @@ def complex_num(pred_bboxes):
|
|||
complex_nums.append(temp_ious[distances.index(min(distances))])
|
||||
return sum(complex_nums) / len(complex_nums)
|
||||
|
||||
|
||||
def get_rows(pred_bboxes):
|
||||
pre_bbox = pred_bboxes[0]
|
||||
res = []
|
||||
|
@ -81,7 +88,9 @@ def get_rows(pred_bboxes):
|
|||
for i in range(step):
|
||||
pred_bboxes.pop(0)
|
||||
return res, pred_bboxes
|
||||
def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
|
||||
|
||||
|
||||
def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
|
||||
ys_1 = []
|
||||
ys_2 = []
|
||||
for box in pred_bboxes:
|
||||
|
@ -95,12 +104,14 @@ def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
|
|||
box[3] = min_y_2
|
||||
re_boxes.append(box)
|
||||
return re_boxes
|
||||
|
||||
|
||||
|
||||
def matcher_refine_row(gt_bboxes, pred_bboxes):
|
||||
before_refine_pred_bboxes = pred_bboxes.copy()
|
||||
pred_bboxes = []
|
||||
while(len(before_refine_pred_bboxes) != 0):
|
||||
row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes)
|
||||
while (len(before_refine_pred_bboxes) != 0):
|
||||
row_bboxes, before_refine_pred_bboxes = get_rows(
|
||||
before_refine_pred_bboxes)
|
||||
print(row_bboxes)
|
||||
pred_bboxes.extend(refine_rows(row_bboxes))
|
||||
all_dis = []
|
||||
|
@ -114,12 +125,11 @@ def matcher_refine_row(gt_bboxes, pred_bboxes):
|
|||
#temp_ious.append(compute_iou(gt_box, pred_box))
|
||||
#all_dis.append(min(distances))
|
||||
#ious.append(temp_ious[distances.index(min(distances))])
|
||||
if distances.index(min(distances)) not in matched.keys():
|
||||
if distances.index(min(distances)) not in matched.keys():
|
||||
matched[distances.index(min(distances))] = [i]
|
||||
else:
|
||||
matched[distances.index(min(distances))].append(i)
|
||||
return matched#, sum(ious) / len(ious)
|
||||
|
||||
return matched #, sum(ious) / len(ious)
|
||||
|
||||
|
||||
#先挑选出一行,再进行匹配
|
||||
|
@ -128,29 +138,30 @@ def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
|
|||
delete_gt_bboxes = gt_bboxes.copy()
|
||||
match_bboxes_ready = []
|
||||
matched = {}
|
||||
while(len(delete_gt_bboxes) != 0):
|
||||
while (len(delete_gt_bboxes) != 0):
|
||||
row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
|
||||
row_bboxes = sorted(row_bboxes, key = lambda key: key[0])
|
||||
row_bboxes = sorted(row_bboxes, key=lambda key: key[0])
|
||||
if len(pred_bboxes_rows) > 0:
|
||||
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
|
||||
print(row_bboxes)
|
||||
for i, gt_box in enumerate(row_bboxes):
|
||||
#print(gt_box)
|
||||
pred_distances = []
|
||||
distances = []
|
||||
distances = []
|
||||
for pred_bbox in pred_bboxes:
|
||||
pred_distances.append(distance(gt_box, pred_bbox))
|
||||
for j, pred_box in enumerate(match_bboxes_ready):
|
||||
distances.append(distance(gt_box, pred_box))
|
||||
index = pred_distances.index(min(distances))
|
||||
#print('index', index)
|
||||
if index not in matched.keys():
|
||||
if index not in matched.keys():
|
||||
matched[index] = [gt_box_index]
|
||||
else:
|
||||
matched[index].append(gt_box_index)
|
||||
gt_box_index += 1
|
||||
return matched
|
||||
|
||||
|
||||
def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
|
||||
'''
|
||||
gt_bboxes: 排序后
|
||||
|
@ -161,7 +172,7 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
|
|||
match_bboxes_ready = []
|
||||
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
|
||||
for i, gt_box in enumerate(gt_bboxes):
|
||||
|
||||
|
||||
pred_distances = []
|
||||
for pred_bbox in pred_bboxes:
|
||||
pred_distances.append(distance(gt_box, pred_bbox))
|
||||
|
@ -184,9 +195,143 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
|
|||
#print(gt_box, index)
|
||||
#match_bboxes_ready.pop(distances.index(min(distances)))
|
||||
print(gt_box, match_bboxes_ready[distances.index(min(distances))])
|
||||
if index not in matched.keys():
|
||||
if index not in matched.keys():
|
||||
matched[index] = [i]
|
||||
else:
|
||||
matched[index].append(i)
|
||||
pre_bbox = gt_box
|
||||
return matched
|
||||
|
||||
|
||||
class TableMatch:
|
||||
def __init__(self, filter_ocr_result=False, use_master=False):
|
||||
self.filter_ocr_result = filter_ocr_result
|
||||
self.use_master = use_master
|
||||
|
||||
def __call__(self, structure_res, dt_boxes, rec_res):
|
||||
pred_structures, pred_bboxes = structure_res
|
||||
if self.filter_ocr_result:
|
||||
dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes, dt_boxes,
|
||||
rec_res)
|
||||
matched_index = self.match_result(dt_boxes, pred_bboxes)
|
||||
if self.use_master:
|
||||
pred_html, pred = self.get_pred_html_master(pred_structures,
|
||||
matched_index, rec_res)
|
||||
else:
|
||||
pred_html, pred = self.get_pred_html(pred_structures, matched_index,
|
||||
rec_res)
|
||||
return pred_html
|
||||
|
||||
def match_result(self, dt_boxes, pred_bboxes):
|
||||
matched = {}
|
||||
for i, gt_box in enumerate(dt_boxes):
|
||||
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
|
||||
distances = []
|
||||
for j, pred_box in enumerate(pred_bboxes):
|
||||
distances.append((distance(gt_box, pred_box),
|
||||
1. - compute_iou(gt_box, pred_box)
|
||||
)) # 获取两两cell之间的L1距离和 1- IOU
|
||||
sorted_distances = distances.copy()
|
||||
# 根据距离和IOU挑选最"近"的cell
|
||||
sorted_distances = sorted(
|
||||
sorted_distances, key=lambda item: (item[1], item[0]))
|
||||
if distances.index(sorted_distances[0]) not in matched.keys():
|
||||
matched[distances.index(sorted_distances[0])] = [i]
|
||||
else:
|
||||
matched[distances.index(sorted_distances[0])].append(i)
|
||||
return matched
|
||||
|
||||
def get_pred_html(self, pred_structures, matched_index, ocr_contents):
|
||||
end_html = []
|
||||
td_index = 0
|
||||
for tag in pred_structures:
|
||||
if '</td>' in tag:
|
||||
if '<td></td>' == tag:
|
||||
end_html.extend('<td>')
|
||||
if td_index in matched_index.keys():
|
||||
b_with = False
|
||||
if '<b>' in ocr_contents[matched_index[td_index][
|
||||
0]] and len(matched_index[td_index]) > 1:
|
||||
b_with = True
|
||||
end_html.extend('<b>')
|
||||
for i, td_index_index in enumerate(matched_index[td_index]):
|
||||
content = ocr_contents[td_index_index][0]
|
||||
if len(matched_index[td_index]) > 1:
|
||||
if len(content) == 0:
|
||||
continue
|
||||
if content[0] == ' ':
|
||||
content = content[1:]
|
||||
if '<b>' in content:
|
||||
content = content[3:]
|
||||
if '</b>' in content:
|
||||
content = content[:-4]
|
||||
if len(content) == 0:
|
||||
continue
|
||||
if i != len(matched_index[
|
||||
td_index]) - 1 and ' ' != content[-1]:
|
||||
content += ' '
|
||||
end_html.extend(content)
|
||||
if b_with:
|
||||
end_html.extend('</b>')
|
||||
if '<td></td>' == tag:
|
||||
end_html.append('</td>')
|
||||
else:
|
||||
end_html.append(tag)
|
||||
td_index += 1
|
||||
else:
|
||||
end_html.append(tag)
|
||||
return ''.join(end_html), end_html
|
||||
|
||||
def get_pred_html_master(self, pred_structures, matched_index,
|
||||
ocr_contents):
|
||||
end_html = []
|
||||
td_index = 0
|
||||
for token in pred_structures:
|
||||
if '</td>' in token:
|
||||
txt = ''
|
||||
b_with = False
|
||||
if td_index in matched_index.keys():
|
||||
if '<b>' in ocr_contents[matched_index[td_index][
|
||||
0]] and len(matched_index[td_index]) > 1:
|
||||
b_with = True
|
||||
for i, td_index_index in enumerate(matched_index[td_index]):
|
||||
content = ocr_contents[td_index_index][0]
|
||||
if len(matched_index[td_index]) > 1:
|
||||
if len(content) == 0:
|
||||
continue
|
||||
if content[0] == ' ':
|
||||
content = content[1:]
|
||||
if '<b>' in content:
|
||||
content = content[3:]
|
||||
if '</b>' in content:
|
||||
content = content[:-4]
|
||||
if len(content) == 0:
|
||||
continue
|
||||
if i != len(matched_index[
|
||||
td_index]) - 1 and ' ' != content[-1]:
|
||||
content += ' '
|
||||
txt += content
|
||||
if b_with:
|
||||
txt = '<b>{}</b>'.format(txt)
|
||||
if '<td></td>' == token:
|
||||
token = '<td>{}</td>'.format(txt)
|
||||
else:
|
||||
token = '{}</td>'.format(txt)
|
||||
td_index += 1
|
||||
token = deal_eb_token(token)
|
||||
end_html.append(token)
|
||||
html = ''.join(end_html)
|
||||
html = deal_bb(html)
|
||||
return html, end_html
|
||||
|
||||
def filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
|
||||
y1 = pred_bboxes[:, 1::2].min()
|
||||
new_dt_boxes = []
|
||||
new_rec_res = []
|
||||
|
||||
for box, rec in zip(dt_boxes, rec_res):
|
||||
if np.max(box[1::2]) < y1:
|
||||
continue
|
||||
new_dt_boxes.append(box)
|
||||
new_rec_res.append(rec)
|
||||
return new_dt_boxes, new_rec_res
|
||||
|
|
|
@ -16,7 +16,7 @@ import sys
|
|||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||
|
||||
|
@ -87,6 +87,7 @@ class TableStructurer(object):
|
|||
utility.create_predictor(args, 'table', logger)
|
||||
|
||||
def __call__(self, img):
|
||||
starttime = time.time()
|
||||
ori_im = img.copy()
|
||||
data = {'image': img}
|
||||
data = transform(data, self.preprocess_op)
|
||||
|
@ -95,7 +96,6 @@ class TableStructurer(object):
|
|||
return None, 0
|
||||
img = np.expand_dims(img, axis=0)
|
||||
img = img.copy()
|
||||
starttime = time.time()
|
||||
|
||||
self.input_tensor.copy_from_cpu(img)
|
||||
self.predictor.run()
|
||||
|
@ -126,7 +126,6 @@ def main(args):
|
|||
table_structurer = TableStructurer(args)
|
||||
count = 0
|
||||
total_time = 0
|
||||
use_xywh = args.table_algorithm in ['TableMaster']
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
with open(
|
||||
os.path.join(args.output, 'infer.txt'), mode='w',
|
||||
|
@ -146,7 +145,7 @@ def main(args):
|
|||
f_w.write("result: {}, {}\n".format(structure_str_list,
|
||||
bbox_list_str))
|
||||
|
||||
img = draw_rectangle(image_file, bbox_list, use_xywh)
|
||||
img = draw_rectangle(image_file, bbox_list)
|
||||
img_save_path = os.path.join(args.output,
|
||||
os.path.basename(image_file))
|
||||
cv2.imwrite(img_save_path, img)
|
||||
|
|
|
@ -18,20 +18,23 @@ import subprocess
|
|||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||
import cv2
|
||||
import copy
|
||||
import logging
|
||||
import numpy as np
|
||||
import time
|
||||
import tools.infer.predict_rec as predict_rec
|
||||
import tools.infer.predict_det as predict_det
|
||||
import tools.infer.utility as utility
|
||||
from tools.infer.predict_system import sorted_boxes
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
from ppocr.utils.logging import get_logger
|
||||
from ppstructure.table.matcher import distance, compute_iou
|
||||
from ppstructure.table.matcher import TableMatch
|
||||
from ppstructure.table.table_master_match import TableMasterMatcher
|
||||
from ppstructure.utility import parse_args
|
||||
import ppstructure.table.predict_structure as predict_strture
|
||||
|
||||
|
@ -55,11 +58,20 @@ def expand(pix, det_box, shape):
|
|||
|
||||
class TableSystem(object):
|
||||
def __init__(self, args, text_detector=None, text_recognizer=None):
|
||||
if not args.show_log:
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
self.text_detector = predict_det.TextDetector(
|
||||
args) if text_detector is None else text_detector
|
||||
self.text_recognizer = predict_rec.TextRecognizer(
|
||||
args) if text_recognizer is None else text_recognizer
|
||||
|
||||
self.table_structurer = predict_strture.TableStructurer(args)
|
||||
if args.table_algorithm in ['TableMaster']:
|
||||
self.match = TableMasterMatcher()
|
||||
else:
|
||||
self.match = TableMatch()
|
||||
|
||||
self.benchmark = args.benchmark
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
|
||||
args, 'table', logger)
|
||||
|
@ -85,16 +97,47 @@ class TableSystem(object):
|
|||
|
||||
def __call__(self, img, return_ocr_result_in_table=False):
|
||||
result = dict()
|
||||
ori_im = img.copy()
|
||||
time_dict = {'det': 0, 'rec': 0, 'table': 0, 'all': 0, 'match': 0}
|
||||
start = time.time()
|
||||
|
||||
structure_res, elapse = self._structure(copy.deepcopy(img))
|
||||
time_dict['table'] = elapse
|
||||
|
||||
dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
|
||||
copy.deepcopy(img))
|
||||
time_dict['det'] = det_elapse
|
||||
time_dict['rec'] = rec_elapse
|
||||
|
||||
if return_ocr_result_in_table:
|
||||
result['boxes'] = dt_boxes #[x.tolist() for x in dt_boxes]
|
||||
result['rec_res'] = rec_res
|
||||
|
||||
tic = time.time()
|
||||
pred_html = self.match(structure_res, dt_boxes, rec_res)
|
||||
toc = time.time()
|
||||
time_dict['match'] = toc - tic
|
||||
# pred_html = self.match(1, 1, 1,img_name)
|
||||
result['html'] = pred_html
|
||||
if self.benchmark:
|
||||
self.autolog.times.end(stamp=True)
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
return result, time_dict
|
||||
|
||||
def _structure(self, img):
|
||||
if self.benchmark:
|
||||
self.autolog.times.start()
|
||||
structure_res, elapse = self.table_structurer(copy.deepcopy(img))
|
||||
return structure_res, elapse
|
||||
|
||||
def _ocr(self, img):
|
||||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
|
||||
dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img))
|
||||
dt_boxes = sorted_boxes(dt_boxes)
|
||||
if return_ocr_result_in_table:
|
||||
result['boxes'] = [x.tolist() for x in dt_boxes]
|
||||
|
||||
r_boxes = []
|
||||
for box in dt_boxes:
|
||||
x_min = box[:, 0].min() - 1
|
||||
|
@ -105,125 +148,20 @@ class TableSystem(object):
|
|||
r_boxes.append(box)
|
||||
dt_boxes = np.array(r_boxes)
|
||||
logger.debug("dt_boxes num : {}, elapse : {}".format(
|
||||
len(dt_boxes), elapse))
|
||||
len(dt_boxes), det_elapse))
|
||||
if dt_boxes is None:
|
||||
return None, None
|
||||
|
||||
img_crop_list = []
|
||||
for i in range(len(dt_boxes)):
|
||||
det_box = dt_boxes[i]
|
||||
x0, y0, x1, y1 = expand(2, det_box, ori_im.shape)
|
||||
text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
|
||||
x0, y0, x1, y1 = expand(2, det_box, img.shape)
|
||||
text_rect = img[int(y0):int(y1), int(x0):int(x1), :]
|
||||
img_crop_list.append(text_rect)
|
||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||
rec_res, rec_elapse = self.text_recognizer(img_crop_list)
|
||||
logger.debug("rec_res num : {}, elapse : {}".format(
|
||||
len(rec_res), elapse))
|
||||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
if return_ocr_result_in_table:
|
||||
result['rec_res'] = rec_res
|
||||
pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
|
||||
result['html'] = pred_html
|
||||
if self.benchmark:
|
||||
self.autolog.times.end(stamp=True)
|
||||
return result
|
||||
|
||||
def rebuild_table(self, structure_res, dt_boxes, rec_res):
|
||||
pred_structures, pred_bboxes = structure_res
|
||||
dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes,dt_boxes, rec_res)
|
||||
matched_index = self.match_result(dt_boxes, pred_bboxes)
|
||||
pred_html, pred = self.get_pred_html(pred_structures, matched_index,
|
||||
rec_res)
|
||||
return pred_html, pred
|
||||
|
||||
def filter_ocr_result(self, pred_bboxes,dt_boxes, rec_res):
|
||||
y1 = pred_bboxes[:,1::2].min()
|
||||
new_dt_boxes = []
|
||||
new_rec_res = []
|
||||
|
||||
for box,rec in zip(dt_boxes, rec_res):
|
||||
if np.max(box[1::2]) < y1:
|
||||
continue
|
||||
new_dt_boxes.append(box)
|
||||
new_rec_res.append(rec)
|
||||
return new_dt_boxes, new_rec_res
|
||||
|
||||
|
||||
def match_result(self, dt_boxes, pred_bboxes):
|
||||
matched = {}
|
||||
for i, gt_box in enumerate(dt_boxes):
|
||||
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
|
||||
distances = []
|
||||
for j, pred_box in enumerate(pred_bboxes):
|
||||
distances.append((distance(gt_box, pred_box),
|
||||
1. - compute_iou(gt_box, pred_box)
|
||||
)) # 获取两两cell之间的L1距离和 1- IOU
|
||||
sorted_distances = distances.copy()
|
||||
# 根据距离和IOU挑选最"近"的cell
|
||||
sorted_distances = sorted(
|
||||
sorted_distances, key=lambda item: (item[1], item[0]))
|
||||
if distances.index(sorted_distances[0]) not in matched.keys():
|
||||
matched[distances.index(sorted_distances[0])] = [i]
|
||||
else:
|
||||
matched[distances.index(sorted_distances[0])].append(i)
|
||||
return matched
|
||||
|
||||
def get_pred_html(self, pred_structures, matched_index, ocr_contents):
|
||||
end_html = []
|
||||
td_index = 0
|
||||
for tag in pred_structures:
|
||||
if '</td>' in tag:
|
||||
if td_index in matched_index.keys():
|
||||
b_with = False
|
||||
if '<b>' in ocr_contents[matched_index[td_index][
|
||||
0]] and len(matched_index[td_index]) > 1:
|
||||
b_with = True
|
||||
end_html.extend('<b>')
|
||||
for i, td_index_index in enumerate(matched_index[td_index]):
|
||||
content = ocr_contents[td_index_index][0]
|
||||
if len(matched_index[td_index]) > 1:
|
||||
if len(content) == 0:
|
||||
continue
|
||||
if content[0] == ' ':
|
||||
content = content[1:]
|
||||
if '<b>' in content:
|
||||
content = content[3:]
|
||||
if '</b>' in content:
|
||||
content = content[:-4]
|
||||
if len(content) == 0:
|
||||
continue
|
||||
if i != len(matched_index[
|
||||
td_index]) - 1 and ' ' != content[-1]:
|
||||
content += ' '
|
||||
end_html.extend(content)
|
||||
if b_with:
|
||||
end_html.extend('</b>')
|
||||
|
||||
end_html.append(tag)
|
||||
td_index += 1
|
||||
else:
|
||||
end_html.append(tag)
|
||||
return ''.join(end_html), end_html
|
||||
|
||||
|
||||
def sorted_boxes(dt_boxes):
|
||||
"""
|
||||
Sort text boxes in order from top to bottom, left to right
|
||||
args:
|
||||
dt_boxes(array):detected text boxes with shape [4, 2]
|
||||
return:
|
||||
sorted boxes(array) with shape [4, 2]
|
||||
"""
|
||||
num_boxes = dt_boxes.shape[0]
|
||||
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
|
||||
_boxes = list(sorted_boxes)
|
||||
|
||||
for i in range(num_boxes - 1):
|
||||
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
|
||||
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
|
||||
tmp = _boxes[i]
|
||||
_boxes[i] = _boxes[i + 1]
|
||||
_boxes[i + 1] = tmp
|
||||
return _boxes
|
||||
len(rec_res), rec_elapse))
|
||||
return dt_boxes, rec_res, det_elapse, rec_elapse
|
||||
|
||||
|
||||
def to_excel(html_table, excel_path):
|
||||
|
@ -249,7 +187,7 @@ def main(args):
|
|||
logger.error("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
starttime = time.time()
|
||||
pred_res = text_sys(img)
|
||||
pred_res, _ = text_sys(img)
|
||||
pred_html = pred_res['html']
|
||||
logger.info(pred_html)
|
||||
to_excel(pred_html, excel_path)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -32,6 +32,7 @@ def init_args():
|
|||
type=str,
|
||||
default="../ppocr/utils/dict/table_structure_dict.txt")
|
||||
# params for layout
|
||||
parser.add_argument("--layout_model_dir", type=str)
|
||||
parser.add_argument(
|
||||
"--layout_path_model",
|
||||
type=str,
|
||||
|
@ -87,7 +88,7 @@ def draw_structure_result(image, result, font_path):
|
|||
image = Image.fromarray(image)
|
||||
boxes, txts, scores = [], [], []
|
||||
for region in result:
|
||||
if region['type'] == 'Table':
|
||||
if region['type'] == 'table':
|
||||
pass
|
||||
else:
|
||||
for text_result in region['res']:
|
||||
|
|
|
@ -19,8 +19,6 @@ Global:
|
|||
character_type: en
|
||||
max_text_length: 800
|
||||
infer_mode: False
|
||||
process_total_num: 0
|
||||
process_cut_num: 0
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
|
|
|
@ -16,8 +16,6 @@ Global:
|
|||
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
|
||||
infer_mode: false
|
||||
max_text_length: 500
|
||||
process_total_num: 0
|
||||
process_cut_num: 0
|
||||
|
||||
|
||||
Optimizer:
|
||||
|
@ -86,7 +84,7 @@ Train:
|
|||
- PaddingTableImage:
|
||||
size: [480, 480]
|
||||
- TableBoxEncode:
|
||||
use_xywh: True
|
||||
box_format: 'xywh'
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
|
@ -120,7 +118,7 @@ Eval:
|
|||
- PaddingTableImage:
|
||||
size: [480, 480]
|
||||
- TableBoxEncode:
|
||||
use_xywh: True
|
||||
box_format: 'xywh'
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
|
|
|
@ -65,9 +65,11 @@ class TextSystem(object):
|
|||
self.crop_image_res_index += bbox_num
|
||||
|
||||
def __call__(self, img, cls=True):
|
||||
time_dict = {'det': 0, 'rec': 0, 'csl': 0, 'all': 0}
|
||||
start = time.time()
|
||||
ori_im = img.copy()
|
||||
dt_boxes, elapse = self.text_detector(img)
|
||||
|
||||
time_dict['det'] = elapse
|
||||
logger.debug("dt_boxes num : {}, elapse : {}".format(
|
||||
len(dt_boxes), elapse))
|
||||
if dt_boxes is None:
|
||||
|
@ -83,10 +85,12 @@ class TextSystem(object):
|
|||
if self.use_angle_cls and cls:
|
||||
img_crop_list, angle_list, elapse = self.text_classifier(
|
||||
img_crop_list)
|
||||
time_dict['cls'] = elapse
|
||||
logger.debug("cls num : {}, elapse : {}".format(
|
||||
len(img_crop_list), elapse))
|
||||
|
||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||
time_dict['rec'] = elapse
|
||||
logger.debug("rec_res num : {}, elapse : {}".format(
|
||||
len(rec_res), elapse))
|
||||
if self.args.save_crop_res:
|
||||
|
@ -98,7 +102,9 @@ class TextSystem(object):
|
|||
if score >= self.drop_score:
|
||||
filter_boxes.append(box)
|
||||
filter_rec_res.append(rec_result)
|
||||
return filter_boxes, filter_rec_res
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
return filter_boxes, filter_rec_res, time_dict
|
||||
|
||||
|
||||
def sorted_boxes(dt_boxes):
|
||||
|
@ -133,9 +139,11 @@ def main(args):
|
|||
os.makedirs(draw_img_save_dir, exist_ok=True)
|
||||
save_results = []
|
||||
|
||||
logger.info("In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
|
||||
"if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320")
|
||||
|
||||
logger.info(
|
||||
"In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
|
||||
"if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320"
|
||||
)
|
||||
|
||||
# warm up 10 times
|
||||
if args.warmup:
|
||||
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
|
||||
|
@ -155,7 +163,7 @@ def main(args):
|
|||
logger.debug("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
starttime = time.time()
|
||||
dt_boxes, rec_res = text_sys(img)
|
||||
dt_boxes, rec_res, time_dict = text_sys(img)
|
||||
elapse = time.time() - starttime
|
||||
total_time += elapse
|
||||
|
||||
|
@ -198,7 +206,10 @@ def main(args):
|
|||
text_sys.text_detector.autolog.report()
|
||||
text_sys.text_recognizer.autolog.report()
|
||||
|
||||
with open(os.path.join(draw_img_save_dir, "system_results.txt"), 'w', encoding='utf-8') as f:
|
||||
with open(
|
||||
os.path.join(draw_img_save_dir, "system_results.txt"),
|
||||
'w',
|
||||
encoding='utf-8') as f:
|
||||
f.writelines(save_results)
|
||||
|
||||
|
||||
|
|
|
@ -155,6 +155,8 @@ def create_predictor(args, mode, logger):
|
|||
model_dir = args.table_model_dir
|
||||
elif mode == 'ser':
|
||||
model_dir = args.ser_model_dir
|
||||
elif mode == 'layout':
|
||||
model_dir = args.layout_model_dir
|
||||
else:
|
||||
model_dir = args.e2e_model_dir
|
||||
|
||||
|
|
|
@ -56,7 +56,6 @@ def main(config, device, logger, vdl_writer):
|
|||
|
||||
model = build_model(config['Architecture'])
|
||||
algorithm = config['Architecture']['algorithm']
|
||||
use_xywh = algorithm in ['TableMaster']
|
||||
|
||||
load_model(config, model)
|
||||
|
||||
|
@ -106,7 +105,7 @@ def main(config, device, logger, vdl_writer):
|
|||
f_w.write("result: {}, {}\n".format(structure_str_list,
|
||||
bbox_list_str))
|
||||
|
||||
img = draw_rectangle(file, bbox_list, use_xywh)
|
||||
img = draw_rectangle(file, bbox_list)
|
||||
cv2.imwrite(
|
||||
os.path.join(save_res_path, os.path.basename(file)), img)
|
||||
logger.info("success!")
|
||||
|
|
|
@ -154,6 +154,7 @@ def check_xpu(use_xpu):
|
|||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def to_float32(preds):
|
||||
if isinstance(preds, dict):
|
||||
for k in preds:
|
||||
|
@ -173,6 +174,7 @@ def to_float32(preds):
|
|||
preds = preds.astype(paddle.float32)
|
||||
return preds
|
||||
|
||||
|
||||
def train(config,
|
||||
train_dataloader,
|
||||
valid_dataloader,
|
||||
|
@ -596,7 +598,7 @@ def preprocess(is_train=False):
|
|||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
|
||||
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN'
|
||||
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'SLANet'
|
||||
]
|
||||
|
||||
if use_xpu:
|
||||
|
|
|
@ -119,6 +119,10 @@ def main(config, device, logger, vdl_writer):
|
|||
config['Loss']['ignore_index'] = char_num - 1
|
||||
|
||||
model = build_model(config['Architecture'])
|
||||
use_sync_bn = config["Global"].get("use_sync_bn", False)
|
||||
if use_sync_bn:
|
||||
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
logger.info('convert_sync_batchnorm')
|
||||
if config['Global']['distributed']:
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
|
@ -157,7 +161,8 @@ def main(config, device, logger, vdl_writer):
|
|||
scaler = paddle.amp.GradScaler(
|
||||
init_loss_scaling=scale_loss,
|
||||
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
|
||||
model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2', master_weight=True)
|
||||
model, optimizer = paddle.amp.decorate(
|
||||
models=model, optimizers=optimizer, level='O2', master_weight=True)
|
||||
else:
|
||||
scaler = None
|
||||
|
||||
|
|
Loading…
Reference in New Issue