152 lines
5.9 KiB
Python
152 lines
5.9 KiB
Python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import cv2
|
|
import numpy as np
|
|
from copy import deepcopy
|
|
|
|
|
|
def trans_poly_to_bbox(poly):
    """Return the axis-aligned bounding box of a polygon.

    Args:
        poly: sequence of points, each indexable as ``(x, y)``. It is
            iterated more than once, so pass a list/array, not a generator.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as numpy scalars (via ``np.min`` /
        ``np.max``).
    """
    xs = [pt[0] for pt in poly]
    ys = [pt[1] for pt in poly]
    return [np.min(xs), np.min(ys), np.max(xs), np.max(ys)]
|
|
|
|
|
|
def get_outer_poly(bbox_list):
    """Return the 4-point polygon that encloses every box in ``bbox_list``.

    Args:
        bbox_list: non-empty sequence of boxes ``[x1, y1, x2, y2]``.

    Returns:
        ``[[x_lo, y_lo], [x_hi, y_lo], [x_hi, y_hi], [x_lo, y_hi]]``, where
        ``x_lo``/``y_lo`` are the minima of ``x1``/``y1`` and ``x_hi``/``y_hi``
        are the maxima of ``x2``/``y2`` over all boxes.
    """
    x_lo = min(box[0] for box in bbox_list)
    y_lo = min(box[1] for box in bbox_list)
    x_hi = max(box[2] for box in bbox_list)
    y_hi = max(box[3] for box in bbox_list)
    return [[x_lo, y_lo], [x_hi, y_lo], [x_hi, y_hi], [x_lo, y_hi]]
|
|
|
|
|
|
def _split_words_into_segments(words, line_break_thresh):
    """Group consecutive FUNSD words of one entity into segments.

    A new segment starts whenever a word's ``box[0]`` plus
    ``line_break_thresh`` is at or left of the previous word's ``box[2]``
    (i.e. the text jumped back to the left, typically a line break).

    Args:
        words: non-empty list of word dicts with ``"box"`` and ``"text"``.
        line_break_thresh: tolerance in box-coordinate units.

    Returns:
        list of ``(bboxes, texts)`` tuples, one per segment, in word order.
    """
    segments = []
    bboxes, texts = [words[0]["box"]], [words[0]["text"]]
    for prev, word in zip(words, words[1:]):
        if word["box"][0] + line_break_thresh <= prev["box"][2]:
            segments.append((bboxes, texts))
            bboxes, texts = [], []
        bboxes.append(word["box"])
        texts.append(word["text"])
    segments.append((bboxes, texts))
    return segments


def _sort_reading_order(res, same_line_thresh):
    """Sort entries top-to-bottom, then left-to-right within a visual line.

    Entries are first sorted by the (y, x) of their first polygon point. A
    bubble pass then swaps neighbours whose first-point y differs by less
    than ``same_line_thresh`` but are out of x order, so boxes on the same
    line with slightly different tops are ordered left to right.

    Returns:
        a new list with the same entry objects, reordered.
    """
    res = sorted(res, key=lambda r: (r["points"][0][1], r["points"][0][0]))
    for i in range(len(res) - 1):
        # Walk back down to index 0 inclusive; stopping at 1 would never
        # compare/swap the first two entries.
        for j in range(i, -1, -1):
            cur, nxt = res[j], res[j + 1]
            same_line = abs(nxt["points"][0][1] - cur["points"][0][1]) < same_line_thresh
            if same_line and nxt["points"][0][0] < cur["points"][0][0]:
                res[j], res[j + 1] = nxt, cur
            else:
                break
    return res


def _remap_links(res, old_id2new_id_map):
    """Rewrite each entry's ``"linking"`` from original FUNSD ids to new ids.

    Links whose endpoints have no emitted segment are dropped. A link between
    entities that were split into several segments expands to every
    (source segment, destination segment) pair. Mutates ``res`` in place.
    """
    for r in res:
        new_links = []
        for link in r["linking"]:
            src_old, dst_old = link[0], link[1]
            # illegal links (endpoint entity was skipped) are removed
            if src_old not in old_id2new_id_map or dst_old not in old_id2new_id_map:
                continue
            for src in old_id2new_id_map[src_old]:
                for dst in old_id2new_id_map[dst_old]:
                    new_links.append([src, dst])
        r["linking"] = new_links


def load_funsd_label(image_dir, anno_dir, *, line_break_thresh=10,
                     same_line_thresh=20):
    """Load FUNSD annotations and re-segment entities into line-level entries.

    Every entity in an annotation's ``"form"`` list is split into segments
    (see ``_split_words_into_segments``). Each segment with non-empty text
    becomes an entry with keys ``"transcription"``, ``"label"``, ``"points"``,
    ``"linking"`` and ``"id"``. Entries are sorted into reading order, and their
    links are rewritten to refer to the new ids.

    Args:
        image_dir: image directory of the split. Not read by this function;
            kept for interface compatibility.
        anno_dir: directory of FUNSD ``*.json`` annotation files; other files
            are ignored.
        line_break_thresh: segment-break tolerance (default 10, the
            previously hard-coded value).
        same_line_thresh: y tolerance for treating two entries as the same
            visual line when ordering (default 20, the previously hard-coded
            value).

    Returns:
        dict mapping each annotation file name (without ``.json``) to its
        list of entries.
    """
    suffix = ".json"
    anno_names = sorted(
        fn[:-len(suffix)] for fn in os.listdir(anno_dir) if fn.endswith(suffix))

    fn_info_map = {}
    for anno_fn in anno_names:
        with open(os.path.join(anno_dir, anno_fn + suffix), "r",
                  encoding="utf-8") as fin:
            infos = json.load(fin)["form"]

        res = []
        old_id2new_id_map = {}  # original entity id -> list of new ids
        for info in infos:
            if info["text"] is None or not info["words"]:
                continue
            segments = _split_words_into_segments(info["words"],
                                                  line_break_thresh)
            for bboxes, texts in segments:
                # Skip segments whose words are all empty strings.
                if not "".join(texts):
                    continue
                new_id = len(res)  # ids are sequential per file
                res.append({
                    "transcription": " ".join(texts),
                    "label": info["label"],
                    "points": get_outer_poly(bboxes),
                    "linking": info["linking"],
                    "id": new_id,
                })
                old_id2new_id_map.setdefault(info["id"], []).append(new_id)

        res = _sort_reading_order(res, same_line_thresh)
        _remap_links(res, old_id2new_id_map)
        fn_info_map[anno_fn] = res

    return fn_info_map
|
|
|
|
|
|
def _convert_split(image_dir, anno_dir, output_path):
    """Convert one FUNSD split and write it as ``<name>.png\\t<json>`` lines."""
    fn_info_map = load_funsd_label(image_dir, anno_dir)
    # ensure_ascii=False may emit non-ASCII text, so write UTF-8 explicitly
    # rather than relying on the platform's default encoding.
    with open(output_path, "w", encoding="utf-8") as fout:
        for fn, info in fn_info_map.items():
            fout.write(fn + ".png" + "\t" +
                       json.dumps(info, ensure_ascii=False) + "\n")


def main():
    """Convert the FUNSD testing and training splits under train_data/FUNSD."""
    splits = (
        ("train_data/FUNSD/testing_data/images/",
         "train_data/FUNSD/testing_data/annotations/",
         "train_data/FUNSD/test.json"),
        ("train_data/FUNSD/training_data/images/",
         "train_data/FUNSD/training_data/annotations/",
         "train_data/FUNSD/train.json"),
    )
    for image_dir, anno_dir, output_path in splits:
        _convert_split(image_dir, anno_dir, output_path)
    print("====ok====")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the conversion only when executed as a script, not on import.
    main()
|