commit
ff185e6004
|
@ -0,0 +1,111 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 300
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 50
|
||||
save_model_dir: ./output/kie_5/
|
||||
save_epoch_step: 50
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [ 0, 80 ]
|
||||
# 1. If pretrained_model is saved in static mode, such as classification pretrained model
|
||||
# from static branch, load_static_weights must be set as True.
|
||||
# 2. If you want to finetune the pretrained models we provide in the docs,
|
||||
# you should set load_static_weights as False.
|
||||
load_static_weights: False
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./output/kie_4/best_accuracy
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
class_path: ./train_data/wildreceipt/class_list.txt
|
||||
infer_img: ./train_data/wildreceipt/1.txt
|
||||
save_res_path: ./output/sdmgr_kie/predicts_kie.txt
|
||||
img_scale: [ 1024, 512 ]
|
||||
|
||||
Architecture:
|
||||
model_type: kie
|
||||
algorithm: SDMGR
|
||||
Transform:
|
||||
Backbone:
|
||||
name: Kie_backbone
|
||||
Head:
|
||||
name: SDMGRHead
|
||||
|
||||
Loss:
|
||||
name: SDMGRLoss
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Piecewise
|
||||
learning_rate: 0.001
|
||||
decay_epochs: [ 60, 80, 100]
|
||||
values: [ 0.001, 0.0001, 0.00001]
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0.00005
|
||||
|
||||
PostProcess:
|
||||
name: None
|
||||
|
||||
Metric:
|
||||
name: KIEMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/wildreceipt/
|
||||
label_file_list: [ './train_data/wildreceipt/wildreceipt_train.txt' ]
|
||||
ratio_list: [ 1.0 ]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- NormalizeImage:
|
||||
scale: 1
|
||||
mean: [ 123.675, 116.28, 103.53 ]
|
||||
std: [ 58.395, 57.12, 57.375 ]
|
||||
order: 'hwc'
|
||||
- KieLabelEncode: # Class handling label
|
||||
character_dict_path: ./train_data/wildreceipt/dict.txt
|
||||
- KieResize:
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'image', 'relations', 'texts', 'points', 'labels', 'tag', 'shape'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 4
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/wildreceipt
|
||||
label_file_list:
|
||||
- ./train_data/wildreceipt/wildreceipt_test.txt
|
||||
# - /paddle/data/PaddleOCR/train_data/wildreceipt/1.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- KieLabelEncode: # Class handling label
|
||||
character_dict_path: ./train_data/wildreceipt/dict.txt
|
||||
- KieResize:
|
||||
- NormalizeImage:
|
||||
scale: 1
|
||||
mean: [ 123.675, 116.28, 103.53 ]
|
||||
std: [ 58.395, 57.12, 57.375 ]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'image', 'relations', 'texts', 'points', 'labels', 'tag', 'ori_image', 'ori_boxes', 'shape']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 4
|
|
@ -21,6 +21,7 @@ PaddleOCR开源的文本检测算法列表:
|
|||
- [x] EAST([paper](https://arxiv.org/abs/1704.03155))[1]
|
||||
- [x] SAST([paper](https://arxiv.org/abs/1908.05498))[4]
|
||||
- [x] PSENet([paper](https://arxiv.org/abs/1903.12473v2))
|
||||
- [x] SDMGR([paper](https://arxiv.org/pdf/2103.14470.pdf))
|
||||
|
||||
在ICDAR2015文本检测公开数据集上,算法效果如下:
|
||||
|模型|骨干网络|precision|recall|Hmean|下载链接|
|
||||
|
@ -32,6 +33,7 @@ PaddleOCR开源的文本检测算法列表:
|
|||
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
|
||||
|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
|
||||
|PSE|MobileNetV3|82.20%|70.48%|75.89%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
|
||||
|SDMGR|VGG16|-|-|87.11%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/kie_vgg16.tar)|
|
||||
|
||||
在Total-text文本检测公开数据集上,算法效果如下:
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import unicode_literals
|
|||
|
||||
import numpy as np
|
||||
import string
|
||||
from shapely.geometry import LineString, Point, Polygon
|
||||
import json
|
||||
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
@ -286,6 +287,168 @@ class E2ELabelEncodeTrain(object):
|
|||
return data
|
||||
|
||||
|
||||
class KieLabelEncode(object):
|
||||
def __init__(self, character_dict_path, norm=10, directed=False, **kwargs):
|
||||
super(KieLabelEncode, self).__init__()
|
||||
self.dict = dict({'': 0})
|
||||
with open(character_dict_path, 'r') as fr:
|
||||
idx = 1
|
||||
for line in fr:
|
||||
char = line.strip()
|
||||
self.dict[char] = idx
|
||||
idx += 1
|
||||
self.norm = norm
|
||||
self.directed = directed
|
||||
|
||||
def compute_relation(self, boxes):
|
||||
"""Compute relation between every two boxes."""
|
||||
x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
|
||||
x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
|
||||
ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
|
||||
dxs = (x1s[:, 0][None] - x1s) / self.norm
|
||||
dys = (y1s[:, 0][None] - y1s) / self.norm
|
||||
xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
|
||||
whs = ws / hs + np.zeros_like(xhhs)
|
||||
relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
|
||||
bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
|
||||
return relations, bboxes
|
||||
|
||||
def pad_text_indices(self, text_inds):
|
||||
"""Pad text index to same length."""
|
||||
max_len = 300
|
||||
recoder_len = max([len(text_ind) for text_ind in text_inds])
|
||||
padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
|
||||
for idx, text_ind in enumerate(text_inds):
|
||||
padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
|
||||
return padded_text_inds, recoder_len
|
||||
|
||||
def list_to_numpy(self, ann_infos):
|
||||
"""Convert bboxes, relations, texts and labels to ndarray."""
|
||||
boxes, text_inds = ann_infos['points'], ann_infos['text_inds']
|
||||
boxes = np.array(boxes, np.int32)
|
||||
relations, bboxes = self.compute_relation(boxes)
|
||||
|
||||
labels = ann_infos.get('labels', None)
|
||||
if labels is not None:
|
||||
labels = np.array(labels, np.int32)
|
||||
edges = ann_infos.get('edges', None)
|
||||
if edges is not None:
|
||||
labels = labels[:, None]
|
||||
edges = np.array(edges)
|
||||
edges = (edges[:, None] == edges[None, :]).astype(np.int32)
|
||||
if self.directed:
|
||||
edges = (edges & labels == 1).astype(np.int32)
|
||||
np.fill_diagonal(edges, -1)
|
||||
labels = np.concatenate([labels, edges], -1)
|
||||
padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
|
||||
max_num = 300
|
||||
temp_bboxes = np.zeros([max_num, 4])
|
||||
h, _ = bboxes.shape
|
||||
temp_bboxes[:h, :h] = bboxes
|
||||
|
||||
temp_relations = np.zeros([max_num, max_num, 5])
|
||||
temp_relations[:h, :h, :] = relations
|
||||
|
||||
temp_padded_text_inds = np.zeros([max_num, max_num])
|
||||
temp_padded_text_inds[:h, :] = padded_text_inds
|
||||
|
||||
temp_labels = np.zeros([max_num, max_num])
|
||||
temp_labels[:h, :h + 1] = labels
|
||||
|
||||
tag = np.array([h, recoder_len])
|
||||
return dict(
|
||||
image=ann_infos['image'],
|
||||
points=temp_bboxes,
|
||||
relations=temp_relations,
|
||||
texts=temp_padded_text_inds,
|
||||
labels=temp_labels,
|
||||
tag=tag)
|
||||
|
||||
def convert_canonical(self, points_x, points_y):
|
||||
|
||||
assert len(points_x) == 4
|
||||
assert len(points_y) == 4
|
||||
|
||||
points = [Point(points_x[i], points_y[i]) for i in range(4)]
|
||||
|
||||
polygon = Polygon([(p.x, p.y) for p in points])
|
||||
min_x, min_y, _, _ = polygon.bounds
|
||||
points_to_lefttop = [
|
||||
LineString([points[i], Point(min_x, min_y)]) for i in range(4)
|
||||
]
|
||||
distances = np.array([line.length for line in points_to_lefttop])
|
||||
sort_dist_idx = np.argsort(distances)
|
||||
lefttop_idx = sort_dist_idx[0]
|
||||
|
||||
if lefttop_idx == 0:
|
||||
point_orders = [0, 1, 2, 3]
|
||||
elif lefttop_idx == 1:
|
||||
point_orders = [1, 2, 3, 0]
|
||||
elif lefttop_idx == 2:
|
||||
point_orders = [2, 3, 0, 1]
|
||||
else:
|
||||
point_orders = [3, 0, 1, 2]
|
||||
|
||||
sorted_points_x = [points_x[i] for i in point_orders]
|
||||
sorted_points_y = [points_y[j] for j in point_orders]
|
||||
|
||||
return sorted_points_x, sorted_points_y
|
||||
|
||||
def sort_vertex(self, points_x, points_y):
|
||||
|
||||
assert len(points_x) == 4
|
||||
assert len(points_y) == 4
|
||||
|
||||
x = np.array(points_x)
|
||||
y = np.array(points_y)
|
||||
center_x = np.sum(x) * 0.25
|
||||
center_y = np.sum(y) * 0.25
|
||||
|
||||
x_arr = np.array(x - center_x)
|
||||
y_arr = np.array(y - center_y)
|
||||
|
||||
angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
|
||||
sort_idx = np.argsort(angle)
|
||||
|
||||
sorted_points_x, sorted_points_y = [], []
|
||||
for i in range(4):
|
||||
sorted_points_x.append(points_x[sort_idx[i]])
|
||||
sorted_points_y.append(points_y[sort_idx[i]])
|
||||
|
||||
return self.convert_canonical(sorted_points_x, sorted_points_y)
|
||||
|
||||
def __call__(self, data):
|
||||
import json
|
||||
label = data['label']
|
||||
annotations = json.loads(label)
|
||||
boxes, texts, text_inds, labels, edges = [], [], [], [], []
|
||||
for ann in annotations:
|
||||
box = ann['points']
|
||||
x_list = [box[i][0] for i in range(4)]
|
||||
y_list = [box[i][1] for i in range(4)]
|
||||
sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
|
||||
sorted_box = []
|
||||
for x, y in zip(sorted_x_list, sorted_y_list):
|
||||
sorted_box.append(x)
|
||||
sorted_box.append(y)
|
||||
boxes.append(sorted_box)
|
||||
text = ann['transcription']
|
||||
texts.append(ann['transcription'])
|
||||
text_ind = [self.dict[c] for c in text if c in self.dict]
|
||||
text_inds.append(text_ind)
|
||||
labels.append(ann['label'])
|
||||
edges.append(ann.get('edge', 0))
|
||||
ann_infos = dict(
|
||||
image=data['image'],
|
||||
points=boxes,
|
||||
texts=texts,
|
||||
text_inds=text_inds,
|
||||
edges=edges,
|
||||
labels=labels)
|
||||
|
||||
return self.list_to_numpy(ann_infos)
|
||||
|
||||
|
||||
class AttnLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
|
|
@ -111,7 +111,6 @@ class NormalizeImage(object):
|
|||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
|
||||
assert isinstance(img,
|
||||
np.ndarray), "invalid input 'img' in NormalizeImage"
|
||||
data['image'] = (
|
||||
|
@ -367,3 +366,53 @@ class E2EResizeForTest(object):
|
|||
ratio_w = resize_w / float(w)
|
||||
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
|
||||
class KieResize(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(KieResize, self).__init__()
|
||||
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
|
||||
'img_scale'][1]
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
points = data['points']
|
||||
src_h, src_w, _ = img.shape
|
||||
im_resized, scale_factor, [ratio_h, ratio_w
|
||||
], [new_h, new_w] = self.resize_image(img)
|
||||
resize_points = self.resize_boxes(img, points, scale_factor)
|
||||
data['ori_image'] = img
|
||||
data['ori_boxes'] = points
|
||||
data['points'] = resize_points
|
||||
data['image'] = im_resized
|
||||
data['shape'] = np.array([new_h, new_w])
|
||||
return data
|
||||
|
||||
def resize_image(self, img):
|
||||
norm_img = np.zeros([1024, 1024, 3], dtype='float32')
|
||||
scale = [512, 1024]
|
||||
h, w = img.shape[:2]
|
||||
max_long_edge = max(scale)
|
||||
max_short_edge = min(scale)
|
||||
scale_factor = min(max_long_edge / max(h, w),
|
||||
max_short_edge / min(h, w))
|
||||
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
|
||||
scale_factor) + 0.5)
|
||||
max_stride = 32
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(img, (resize_w, resize_h))
|
||||
new_h, new_w = im.shape[:2]
|
||||
w_scale = new_w / w
|
||||
h_scale = new_h / h
|
||||
scale_factor = np.array(
|
||||
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
|
||||
norm_img[:new_h, :new_w, :] = im
|
||||
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
|
||||
|
||||
def resize_boxes(self, im, points, scale_factor):
|
||||
points = points * scale_factor
|
||||
img_shape = im.shape[:2]
|
||||
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
|
||||
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
|
||||
return points
|
||||
|
|
|
@ -35,6 +35,7 @@ from .cls_loss import ClsLoss
|
|||
|
||||
# e2e loss
|
||||
from .e2e_pg_loss import PGLoss
|
||||
from .kie_sdmgr_loss import SDMGRLoss
|
||||
|
||||
# basic loss function
|
||||
from .basic_loss import DistanceLoss
|
||||
|
@ -50,7 +51,7 @@ def build_loss(config):
|
|||
support_dict = [
|
||||
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
|
||||
'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
|
||||
'TableAttentionLoss', 'SARLoss', 'AsterLoss'
|
||||
'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
import paddle
|
||||
|
||||
|
||||
class SDMGRLoss(nn.Layer):
|
||||
def __init__(self, node_weight=1.0, edge_weight=1.0, ignore=0):
|
||||
super().__init__()
|
||||
self.loss_node = nn.CrossEntropyLoss(ignore_index=ignore)
|
||||
self.loss_edge = nn.CrossEntropyLoss(ignore_index=-1)
|
||||
self.node_weight = node_weight
|
||||
self.edge_weight = edge_weight
|
||||
self.ignore = ignore
|
||||
|
||||
def pre_process(self, gts, tag):
|
||||
gts, tag = gts.numpy(), tag.numpy().tolist()
|
||||
temp_gts = []
|
||||
batch = len(tag)
|
||||
for i in range(batch):
|
||||
num, recoder_len = tag[i][0], tag[i][1]
|
||||
temp_gts.append(
|
||||
paddle.to_tensor(
|
||||
gts[i, :num, :num + 1], dtype='int64'))
|
||||
return temp_gts
|
||||
|
||||
def accuracy(self, pred, target, topk=1, thresh=None):
|
||||
"""Calculate accuracy according to the prediction and target.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The model prediction, shape (N, num_class)
|
||||
target (torch.Tensor): The target of each prediction, shape (N, )
|
||||
topk (int | tuple[int], optional): If the predictions in ``topk``
|
||||
matches the target, the predictions will be regarded as
|
||||
correct ones. Defaults to 1.
|
||||
thresh (float, optional): If not None, predictions with scores under
|
||||
this threshold are considered incorrect. Default to None.
|
||||
|
||||
Returns:
|
||||
float | tuple[float]: If the input ``topk`` is a single integer,
|
||||
the function will return a single float as accuracy. If
|
||||
``topk`` is a tuple containing multiple integers, the
|
||||
function will return a tuple containing accuracies of
|
||||
each ``topk`` number.
|
||||
"""
|
||||
assert isinstance(topk, (int, tuple))
|
||||
if isinstance(topk, int):
|
||||
topk = (topk, )
|
||||
return_single = True
|
||||
else:
|
||||
return_single = False
|
||||
|
||||
maxk = max(topk)
|
||||
if pred.shape[0] == 0:
|
||||
accu = [pred.new_tensor(0.) for i in range(len(topk))]
|
||||
return accu[0] if return_single else accu
|
||||
pred_value, pred_label = paddle.topk(pred, maxk, axis=1)
|
||||
pred_label = pred_label.transpose(
|
||||
[1, 0]) # transpose to shape (maxk, N)
|
||||
correct = paddle.equal(pred_label,
|
||||
(target.reshape([1, -1]).expand_as(pred_label)))
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = paddle.sum(correct[:k].reshape([-1]).astype('float32'),
|
||||
axis=0,
|
||||
keepdim=True)
|
||||
res.append(
|
||||
paddle.multiply(correct_k,
|
||||
paddle.to_tensor(100.0 / pred.shape[0])))
|
||||
return res[0] if return_single else res
|
||||
|
||||
def forward(self, pred, batch):
|
||||
node_preds, edge_preds = pred
|
||||
gts, tag = batch[4], batch[5]
|
||||
gts = self.pre_process(gts, tag)
|
||||
node_gts, edge_gts = [], []
|
||||
for gt in gts:
|
||||
node_gts.append(gt[:, 0])
|
||||
edge_gts.append(gt[:, 1:].reshape([-1]))
|
||||
node_gts = paddle.concat(node_gts)
|
||||
edge_gts = paddle.concat(edge_gts)
|
||||
|
||||
node_valids = paddle.nonzero(node_gts != self.ignore).reshape([-1])
|
||||
edge_valids = paddle.nonzero(edge_gts != -1).reshape([-1])
|
||||
loss_node = self.loss_node(node_preds, node_gts)
|
||||
loss_edge = self.loss_edge(edge_preds, edge_gts)
|
||||
loss = self.node_weight * loss_node + self.edge_weight * loss_edge
|
||||
return dict(
|
||||
loss=loss,
|
||||
loss_node=loss_node,
|
||||
loss_edge=loss_edge,
|
||||
acc_node=self.accuracy(
|
||||
paddle.gather(node_preds, node_valids),
|
||||
paddle.gather(node_gts, node_valids)),
|
||||
acc_edge=self.accuracy(
|
||||
paddle.gather(edge_preds, edge_valids),
|
||||
paddle.gather(edge_gts, edge_valids)))
|
|
@ -27,10 +27,13 @@ from .cls_metric import ClsMetric
|
|||
from .e2e_metric import E2EMetric
|
||||
from .distillation_metric import DistillationMetric
|
||||
from .table_metric import TableMetric
|
||||
from .kie_metric import KIEMetric
|
||||
|
||||
|
||||
def build_metric(config):
|
||||
support_dict = [
|
||||
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric"
|
||||
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric",
|
||||
"DistillationMetric", "TableMetric", 'KIEMetric'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
__all__ = ['KIEMetric']
|
||||
|
||||
|
||||
class KIEMetric(object):
|
||||
def __init__(self, main_indicator='hmean', **kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.reset()
|
||||
self.node = []
|
||||
self.gt = []
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
nodes, _ = preds
|
||||
gts, tag = batch[4].squeeze(0), batch[5].tolist()[0]
|
||||
gts = gts[:tag[0], :1].reshape([-1])
|
||||
self.node.append(nodes.numpy())
|
||||
self.gt.append(gts)
|
||||
# result = self.compute_f1_score(nodes, gts)
|
||||
# self.results.append(result)
|
||||
|
||||
def compute_f1_score(self, preds, gts):
|
||||
ignores = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 25]
|
||||
C = preds.shape[1]
|
||||
classes = np.array(sorted(set(range(C)) - set(ignores)))
|
||||
hist = np.bincount(
|
||||
(gts * C).astype('int64') + preds.argmax(1), minlength=C
|
||||
**2).reshape([C, C]).astype('float32')
|
||||
diag = np.diag(hist)
|
||||
recalls = diag / hist.sum(1).clip(min=1)
|
||||
precisions = diag / hist.sum(0).clip(min=1)
|
||||
f1 = 2 * recalls * precisions / (recalls + precisions).clip(min=1e-8)
|
||||
return f1[classes]
|
||||
|
||||
def combine_results(self, results):
|
||||
node = np.concatenate(self.node, 0)
|
||||
gts = np.concatenate(self.gt, 0)
|
||||
results = self.compute_f1_score(node, gts)
|
||||
data = {'hmean': results.mean()}
|
||||
return data
|
||||
|
||||
def get_metric(self):
|
||||
|
||||
metircs = self.combine_results(self.results)
|
||||
self.reset()
|
||||
return metircs
|
||||
|
||||
def reset(self):
|
||||
self.results = [] # clear results
|
||||
self.node = []
|
||||
self.gt = []
|
|
@ -35,7 +35,14 @@ def build_backbone(config, model_type):
|
|||
]
|
||||
elif model_type == "e2e":
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
support_dict = ["ResNet"]
|
||||
support_dict = ['ResNet']
|
||||
elif model_type == 'kie':
|
||||
from .kie_unet_sdmgr import Kie_backbone
|
||||
support_dict = ['Kie_backbone']
|
||||
elif model_type == "table":
|
||||
from .table_resnet_vd import ResNet
|
||||
from .table_mobilenet_v3 import MobileNetV3
|
||||
support_dict = ["ResNet", "MobileNetV3"]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -0,0 +1,186 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
__all__ = ["Kie_backbone"]
|
||||
|
||||
|
||||
class Encoder(nn.Layer):
|
||||
def __init__(self, num_channels, num_filters):
|
||||
super(Encoder, self).__init__()
|
||||
self.conv1 = nn.Conv2D(
|
||||
num_channels,
|
||||
num_filters,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias_attr=False)
|
||||
self.bn1 = nn.BatchNorm(num_filters, act='relu')
|
||||
|
||||
self.conv2 = nn.Conv2D(
|
||||
num_filters,
|
||||
num_filters,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias_attr=False)
|
||||
self.bn2 = nn.BatchNorm(num_filters, act='relu')
|
||||
|
||||
self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self.conv1(inputs)
|
||||
x = self.bn1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.bn2(x)
|
||||
x_pooled = self.pool(x)
|
||||
return x, x_pooled
|
||||
|
||||
|
||||
class Decoder(nn.Layer):
|
||||
def __init__(self, num_channels, num_filters):
|
||||
super(Decoder, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2D(
|
||||
num_channels,
|
||||
num_filters,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias_attr=False)
|
||||
self.bn1 = nn.BatchNorm(num_filters, act='relu')
|
||||
|
||||
self.conv2 = nn.Conv2D(
|
||||
num_filters,
|
||||
num_filters,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias_attr=False)
|
||||
self.bn2 = nn.BatchNorm(num_filters, act='relu')
|
||||
|
||||
self.conv0 = nn.Conv2D(
|
||||
num_channels,
|
||||
num_filters,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias_attr=False)
|
||||
self.bn0 = nn.BatchNorm(num_filters, act='relu')
|
||||
|
||||
def forward(self, inputs_prev, inputs):
|
||||
x = self.conv0(inputs)
|
||||
x = self.bn0(x)
|
||||
x = paddle.nn.functional.interpolate(
|
||||
x, scale_factor=2, mode='bilinear', align_corners=False)
|
||||
x = paddle.concat([inputs_prev, x], axis=1)
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.bn2(x)
|
||||
return x
|
||||
|
||||
|
||||
class UNet(nn.Layer):
|
||||
def __init__(self):
|
||||
super(UNet, self).__init__()
|
||||
self.down1 = Encoder(num_channels=3, num_filters=16)
|
||||
self.down2 = Encoder(num_channels=16, num_filters=32)
|
||||
self.down3 = Encoder(num_channels=32, num_filters=64)
|
||||
self.down4 = Encoder(num_channels=64, num_filters=128)
|
||||
self.down5 = Encoder(num_channels=128, num_filters=256)
|
||||
|
||||
self.up1 = Decoder(32, 16)
|
||||
self.up2 = Decoder(64, 32)
|
||||
self.up3 = Decoder(128, 64)
|
||||
self.up4 = Decoder(256, 128)
|
||||
self.out_channels = 16
|
||||
|
||||
def forward(self, inputs):
|
||||
x1, _ = self.down1(inputs)
|
||||
_, x2 = self.down2(x1)
|
||||
_, x3 = self.down3(x2)
|
||||
_, x4 = self.down4(x3)
|
||||
_, x5 = self.down5(x4)
|
||||
|
||||
x = self.up4(x4, x5)
|
||||
x = self.up3(x3, x)
|
||||
x = self.up2(x2, x)
|
||||
x = self.up1(x1, x)
|
||||
return x
|
||||
|
||||
|
||||
class Kie_backbone(nn.Layer):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super(Kie_backbone, self).__init__()
|
||||
self.out_channels = 16
|
||||
self.img_feat = UNet()
|
||||
self.maxpool = nn.MaxPool2D(kernel_size=7)
|
||||
|
||||
def bbox2roi(self, bbox_list):
|
||||
rois_list = []
|
||||
rois_num = []
|
||||
for img_id, bboxes in enumerate(bbox_list):
|
||||
rois_num.append(bboxes.shape[0])
|
||||
rois_list.append(bboxes)
|
||||
rois = paddle.concat(rois_list, 0)
|
||||
rois_num = paddle.to_tensor(rois_num, dtype='int32')
|
||||
return rois, rois_num
|
||||
|
||||
def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):
|
||||
img, relations, texts, gt_bboxes, tag, img_size = img.numpy(
|
||||
), relations.numpy(), texts.numpy(), gt_bboxes.numpy(), tag.numpy(
|
||||
).tolist(), img_size.numpy()
|
||||
temp_relations, temp_texts, temp_gt_bboxes = [], [], []
|
||||
h, w = int(np.max(img_size[:, 0])), int(np.max(img_size[:, 1]))
|
||||
img = paddle.to_tensor(img[:, :, :h, :w])
|
||||
batch = len(tag)
|
||||
for i in range(batch):
|
||||
num, recoder_len = tag[i][0], tag[i][1]
|
||||
temp_relations.append(
|
||||
paddle.to_tensor(
|
||||
relations[i, :num, :num, :], dtype='float32'))
|
||||
temp_texts.append(
|
||||
paddle.to_tensor(
|
||||
texts[i, :num, :recoder_len], dtype='float32'))
|
||||
temp_gt_bboxes.append(
|
||||
paddle.to_tensor(
|
||||
gt_bboxes[i, :num, ...], dtype='float32'))
|
||||
return img, temp_relations, temp_texts, temp_gt_bboxes
|
||||
|
||||
def forward(self, inputs):
|
||||
img = inputs[0]
|
||||
relations, texts, gt_bboxes, tag, img_size = inputs[1], inputs[
|
||||
2], inputs[3], inputs[5], inputs[-1]
|
||||
img, relations, texts, gt_bboxes = self.pre_process(
|
||||
img, relations, texts, gt_bboxes, tag, img_size)
|
||||
x = self.img_feat(img)
|
||||
boxes, rois_num = self.bbox2roi(gt_bboxes)
|
||||
feats = paddle.fluid.layers.roi_align(
|
||||
x,
|
||||
boxes,
|
||||
spatial_scale=1.0,
|
||||
pooled_height=7,
|
||||
pooled_width=7,
|
||||
rois_num=rois_num)
|
||||
feats = self.maxpool(feats).squeeze(-1).squeeze(-1)
|
||||
return [relations, texts, feats]
|
|
@ -33,14 +33,19 @@ def build_head(config):
|
|||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
|
||||
#kie head
|
||||
from .kie_sdmgr_head import SDMGRHead
|
||||
|
||||
from .table_att_head import TableAttentionHead
|
||||
|
||||
support_dict = [
|
||||
'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead',
|
||||
'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
|
||||
'TableAttentionHead', 'SARHead', 'AsterHead'
|
||||
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead'
|
||||
]
|
||||
|
||||
#table head
|
||||
from .table_att_head import TableAttentionHead
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,206 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
|
||||
class SDMGRHead(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
num_chars=92,
|
||||
visual_dim=16,
|
||||
fusion_dim=1024,
|
||||
node_input=32,
|
||||
node_embed=256,
|
||||
edge_input=5,
|
||||
edge_embed=256,
|
||||
num_gnn=2,
|
||||
num_classes=26,
|
||||
bidirectional=False):
|
||||
super().__init__()
|
||||
|
||||
self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
|
||||
self.node_embed = nn.Embedding(num_chars, node_input, 0)
|
||||
hidden = node_embed // 2 if bidirectional else node_embed
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=node_input, hidden_size=hidden, num_layers=1)
|
||||
self.edge_embed = nn.Linear(edge_input, edge_embed)
|
||||
self.gnn_layers = nn.LayerList(
|
||||
[GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
|
||||
self.node_cls = nn.Linear(node_embed, num_classes)
|
||||
self.edge_cls = nn.Linear(edge_embed, 2)
|
||||
|
||||
def forward(self, input, targets):
|
||||
relations, texts, x = input
|
||||
node_nums, char_nums = [], []
|
||||
for text in texts:
|
||||
node_nums.append(text.shape[0])
|
||||
char_nums.append(paddle.sum((text > -1).astype(int), axis=-1))
|
||||
|
||||
max_num = max([char_num.max() for char_num in char_nums])
|
||||
all_nodes = paddle.concat([
|
||||
paddle.concat(
|
||||
[text, paddle.zeros(
|
||||
(text.shape[0], max_num - text.shape[1]))], -1)
|
||||
for text in texts
|
||||
])
|
||||
temp = paddle.clip(all_nodes, min=0).astype(int)
|
||||
embed_nodes = self.node_embed(temp)
|
||||
rnn_nodes, _ = self.rnn(embed_nodes)
|
||||
|
||||
b, h, w = rnn_nodes.shape
|
||||
nodes = paddle.zeros([b, w])
|
||||
all_nums = paddle.concat(char_nums)
|
||||
valid = paddle.nonzero((all_nums > 0).astype(int))
|
||||
temp_all_nums = (
|
||||
paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)
|
||||
temp_all_nums = paddle.expand(temp_all_nums, [
|
||||
temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]
|
||||
])
|
||||
temp_all_nodes = paddle.gather(rnn_nodes, valid)
|
||||
N, C, A = temp_all_nodes.shape
|
||||
one_hot = F.one_hot(
|
||||
temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])
|
||||
one_hot = paddle.multiply(
|
||||
temp_all_nodes, one_hot.astype("float32")).sum(axis=1, keepdim=True)
|
||||
t = one_hot.expand([N, 1, A]).squeeze(1)
|
||||
nodes = paddle.scatter(nodes, valid.squeeze(1), t)
|
||||
|
||||
if x is not None:
|
||||
nodes = self.fusion([x, nodes])
|
||||
|
||||
all_edges = paddle.concat(
|
||||
[rel.reshape([-1, rel.shape[-1]]) for rel in relations])
|
||||
embed_edges = self.edge_embed(all_edges.astype('float32'))
|
||||
embed_edges = F.normalize(embed_edges)
|
||||
|
||||
for gnn_layer in self.gnn_layers:
|
||||
nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)
|
||||
|
||||
node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
|
||||
return node_cls, edge_cls
|
||||
|
||||
|
||||
class GNNLayer(nn.Layer):
|
||||
def __init__(self, node_dim=256, edge_dim=256):
|
||||
super().__init__()
|
||||
self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
|
||||
self.coef_fc = nn.Linear(node_dim, 1)
|
||||
self.out_fc = nn.Linear(node_dim, node_dim)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, nodes, edges, nums):
|
||||
start, cat_nodes = 0, []
|
||||
for num in nums:
|
||||
sample_nodes = nodes[start:start + num]
|
||||
cat_nodes.append(
|
||||
paddle.concat([
|
||||
paddle.expand(sample_nodes.unsqueeze(1), [-1, num, -1]),
|
||||
paddle.expand(sample_nodes.unsqueeze(0), [num, -1, -1])
|
||||
], -1).reshape([num**2, -1]))
|
||||
start += num
|
||||
cat_nodes = paddle.concat([paddle.concat(cat_nodes), edges], -1)
|
||||
cat_nodes = self.relu(self.in_fc(cat_nodes))
|
||||
coefs = self.coef_fc(cat_nodes)
|
||||
|
||||
start, residuals = 0, []
|
||||
for num in nums:
|
||||
residual = F.softmax(
|
||||
-paddle.eye(num).unsqueeze(-1) * 1e9 +
|
||||
coefs[start:start + num**2].reshape([num, num, -1]), 1)
|
||||
residuals.append((residual * cat_nodes[start:start + num**2]
|
||||
.reshape([num, num, -1])).sum(1))
|
||||
start += num**2
|
||||
|
||||
nodes += self.relu(self.out_fc(paddle.concat(residuals)))
|
||||
return [nodes, cat_nodes]
|
||||
|
||||
|
||||
class Block(nn.Layer):
|
||||
def __init__(self,
|
||||
input_dims,
|
||||
output_dim,
|
||||
mm_dim=1600,
|
||||
chunks=20,
|
||||
rank=15,
|
||||
shared=False,
|
||||
dropout_input=0.,
|
||||
dropout_pre_lin=0.,
|
||||
dropout_output=0.,
|
||||
pos_norm='before_cat'):
|
||||
super().__init__()
|
||||
self.rank = rank
|
||||
self.dropout_input = dropout_input
|
||||
self.dropout_pre_lin = dropout_pre_lin
|
||||
self.dropout_output = dropout_output
|
||||
assert (pos_norm in ['before_cat', 'after_cat'])
|
||||
self.pos_norm = pos_norm
|
||||
# Modules
|
||||
self.linear0 = nn.Linear(input_dims[0], mm_dim)
|
||||
self.linear1 = (self.linear0
|
||||
if shared else nn.Linear(input_dims[1], mm_dim))
|
||||
self.merge_linears0 = nn.LayerList()
|
||||
self.merge_linears1 = nn.LayerList()
|
||||
self.chunks = self.chunk_sizes(mm_dim, chunks)
|
||||
for size in self.chunks:
|
||||
ml0 = nn.Linear(size, size * rank)
|
||||
self.merge_linears0.append(ml0)
|
||||
ml1 = ml0 if shared else nn.Linear(size, size * rank)
|
||||
self.merge_linears1.append(ml1)
|
||||
self.linear_out = nn.Linear(mm_dim, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.linear0(x[0])
|
||||
x1 = self.linear1(x[1])
|
||||
bs = x1.shape[0]
|
||||
if self.dropout_input > 0:
|
||||
x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
|
||||
x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
|
||||
x0_chunks = paddle.split(x0, self.chunks, -1)
|
||||
x1_chunks = paddle.split(x1, self.chunks, -1)
|
||||
zs = []
|
||||
for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks, self.merge_linears0,
|
||||
self.merge_linears1):
|
||||
m = m0(x0_c) * m1(x1_c) # bs x split_size*rank
|
||||
m = m.reshape([bs, self.rank, -1])
|
||||
z = paddle.sum(m, 1)
|
||||
if self.pos_norm == 'before_cat':
|
||||
z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
|
||||
z = F.normalize(z)
|
||||
zs.append(z)
|
||||
z = paddle.concat(zs, 1)
|
||||
if self.pos_norm == 'after_cat':
|
||||
z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
|
||||
z = F.normalize(z)
|
||||
|
||||
if self.dropout_pre_lin > 0:
|
||||
z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
|
||||
z = self.linear_out(z)
|
||||
if self.dropout_output > 0:
|
||||
z = F.dropout(z, p=self.dropout_output, training=self.training)
|
||||
return z
|
||||
|
||||
def chunk_sizes(self, dim, chunks):
|
||||
split_size = (dim + chunks - 1) // chunks
|
||||
sizes_list = [split_size] * chunks
|
||||
sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
|
||||
return sizes_list
|
|
@ -45,6 +45,8 @@ def build_post_process(config, global_config=None):
|
|||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
if module_name == "None":
|
||||
return
|
||||
if global_config is not None:
|
||||
config.update(global_config)
|
||||
assert module_name in support_dict, Exception(
|
||||
|
|
|
@ -12,4 +12,5 @@ cython
|
|||
lxml
|
||||
premailer
|
||||
openpyxl
|
||||
fasttext==0.9.1
|
||||
fasttext==0.9.1
|
||||
|
||||
|
|
|
@ -54,7 +54,8 @@ def main():
|
|||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
|
||||
model = build_model(config['Architecture'])
|
||||
extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"]
|
||||
extra_input = config['Architecture'][
|
||||
'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
|
||||
if "model_type" in config['Architecture'].keys():
|
||||
model_type = config['Architecture']['model_type']
|
||||
else:
|
||||
|
@ -68,7 +69,6 @@ def main():
|
|||
|
||||
# build metric
|
||||
eval_class = build_metric(config['Metric'])
|
||||
|
||||
# start eval
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class, model_type, extra_input)
|
||||
|
|
|
@ -0,0 +1,139 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import paddle.nn.functional as F
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||
|
||||
import cv2
|
||||
import paddle
|
||||
|
||||
from ppocr.data import create_operators, transform
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.utils.save_load import init_model
|
||||
import tools.program as program
|
||||
|
||||
|
||||
def read_class_list(filepath):
|
||||
dict = {}
|
||||
with open(filepath, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
key, value = line.split(" ")
|
||||
dict[key] = value.rstrip()
|
||||
return dict
|
||||
|
||||
|
||||
def draw_kie_result(batch, node, idx_to_cls, count):
|
||||
img = batch[6].copy()
|
||||
boxes = batch[7]
|
||||
h, w = img.shape[:2]
|
||||
pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255
|
||||
max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
|
||||
node_pred_label = max_idx.numpy().tolist()
|
||||
node_pred_score = max_value.numpy().tolist()
|
||||
|
||||
for i, box in enumerate(boxes):
|
||||
if i >= len(node_pred_label):
|
||||
break
|
||||
new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
|
||||
[box[0], box[3]]]
|
||||
Pts = np.array([new_box], np.int32)
|
||||
cv2.polylines(
|
||||
img, [Pts.reshape((-1, 1, 2))],
|
||||
True,
|
||||
color=(255, 255, 0),
|
||||
thickness=1)
|
||||
x_min = int(min([point[0] for point in new_box]))
|
||||
y_min = int(min([point[1] for point in new_box]))
|
||||
|
||||
pred_label = str(node_pred_label[i])
|
||||
if pred_label in idx_to_cls:
|
||||
pred_label = idx_to_cls[pred_label]
|
||||
pred_score = '{:.2f}'.format(node_pred_score[i])
|
||||
text = pred_label + '(' + pred_score + ')'
|
||||
cv2.putText(pred_img, text, (x_min * 2, y_min),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
|
||||
vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
|
||||
vis_img[:, :w] = img
|
||||
vis_img[:, w:] = pred_img
|
||||
save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/"
|
||||
if not os.path.exists(save_kie_path):
|
||||
os.makedirs(save_kie_path)
|
||||
save_path = os.path.join(save_kie_path, str(count) + ".png")
|
||||
cv2.imwrite(save_path, vis_img)
|
||||
logger.info("The Kie Image saved in {}".format(save_path))
|
||||
|
||||
|
||||
def main():
|
||||
global_config = config['Global']
|
||||
|
||||
# build model
|
||||
model = build_model(config['Architecture'])
|
||||
init_model(config, model, logger)
|
||||
|
||||
# create data ops
|
||||
transforms = []
|
||||
for op in config['Eval']['dataset']['transforms']:
|
||||
transforms.append(op)
|
||||
|
||||
data_dir = config['Eval']['dataset']['data_dir']
|
||||
|
||||
ops = create_operators(transforms, global_config)
|
||||
|
||||
save_res_path = config['Global']['save_res_path']
|
||||
class_path = config['Global']['class_path']
|
||||
idx_to_cls = read_class_list(class_path)
|
||||
if not os.path.exists(os.path.dirname(save_res_path)):
|
||||
os.makedirs(os.path.dirname(save_res_path))
|
||||
|
||||
model.eval()
|
||||
with open(save_res_path, "wb") as fout:
|
||||
with open(config['Global']['infer_img'], "rb") as f:
|
||||
lines = f.readlines()
|
||||
for index, data_line in enumerate(lines):
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").split("\t")
|
||||
img_path, label = data_dir + "/" + substr[0], substr[1]
|
||||
data = {'img_path': img_path, 'label': label}
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
batch = transform(data, ops)
|
||||
batch_pred = [0] * len(batch)
|
||||
for i in range(len(batch)):
|
||||
batch_pred[i] = paddle.to_tensor(
|
||||
np.expand_dims(
|
||||
batch[i], axis=0))
|
||||
node, edge = model(batch_pred)
|
||||
node = F.softmax(node, -1)
|
||||
draw_kie_result(batch, node, idx_to_cls, index)
|
||||
logger.info("success!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config, device, logger, vdl_writer = program.preprocess()
|
||||
main()
|
|
@ -227,6 +227,10 @@ def train(config,
|
|||
images = batch[0]
|
||||
if use_srn:
|
||||
model_average = True
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
if model_type == "kie":
|
||||
preds = model(batch)
|
||||
|
||||
train_start = time.time()
|
||||
# use amp
|
||||
|
@ -266,7 +270,7 @@ def train(config,
|
|||
|
||||
if cal_metric_during_train: # only rec and cls need
|
||||
batch = [item.numpy() for item in batch]
|
||||
if model_type == 'table':
|
||||
if model_type in ['table', 'kie']:
|
||||
eval_class(preds, batch)
|
||||
else:
|
||||
post_result = post_process_class(preds, batch[1])
|
||||
|
@ -399,17 +403,20 @@ def eval(model,
|
|||
start = time.time()
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
if model_type == "kie":
|
||||
preds = model(batch)
|
||||
else:
|
||||
preds = model(images)
|
||||
batch = [item.numpy() for item in batch]
|
||||
# Obtain usable results from post-processing methods
|
||||
total_time += time.time() - start
|
||||
# Evaluate the results of the current batch
|
||||
if model_type == 'table':
|
||||
if model_type in ['table', 'kie']:
|
||||
eval_class(preds, batch)
|
||||
else:
|
||||
post_result = post_process_class(preds, batch[1])
|
||||
eval_class(post_result, batch)
|
||||
|
||||
pbar.update(1)
|
||||
total_frame += len(images)
|
||||
# Get final metric,eg. acc or hmean
|
||||
|
@ -498,8 +505,13 @@ def preprocess(is_train=False):
|
|||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||
'SEED'
|
||||
'SEED', 'SDMGR'
|
||||
]
|
||||
windows_not_support_list = ['PSE']
|
||||
if platform.system() == "Windows" and alg in windows_not_support_list:
|
||||
logger.warning('{} is not support in Windows now'.format(
|
||||
windows_not_support_list))
|
||||
sys.exit()
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
device = paddle.set_device(device)
|
||||
|
|
Loading…
Reference in New Issue