merge infer and train
parent 5b307b4b50
commit 434ab12276
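Summary of the change, reconstructed from the hunks below: the separate _train and _infer code paths of VQATokenLabelEncode collapse into a single __call__ that branches on infer_mode, with the duplicated logic factored into three helpers: _load_ocr_info (OCR results come from the OCR engine in infer mode and from the JSON label otherwise), _smooth_box (bbox normalization), and _parse_label (token-level label ids).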
@@ -821,54 +821,44 @@ class VQATokenLabelEncode(object):
         self.ocr_engine = ocr_engine
 
     def __call__(self, data):
-        if self.infer_mode == False:
-            return self._train(data)
-        else:
-            return self._infer(data)
+        # load bbox and label info
+        ocr_info = self._load_ocr_info(data)
 
-    def _train(self, data):
-        info = data['label']
-
-        # read text info
-        info_dict = json.loads(info)
-        height = info_dict["height"]
-        width = info_dict["width"]
+        height, width, _ = data['image'].shape
 
         words_list = []
         bbox_list = []
         input_ids_list = []
         token_type_ids_list = []
         segment_offset_id = []
         gt_label_list = []
 
         if self.contains_re:
             # for re
             entities = []
-            relations = []
-            id2label = {}
-            entity_id_to_index_map = {}
-            empty_entity = set()
-        for info in info_dict["ocr_info"]:
-            if self.contains_re:
+            if not self.infer_mode:
+                relations = []
+                id2label = {}
+                entity_id_to_index_map = {}
+                empty_entity = set()
+
+        data['ocr_info'] = copy.deepcopy(ocr_info)
+
+        for info in ocr_info:
+            if self.contains_re and not self.infer_mode:
                 # for re
                 if len(info["text"]) == 0:
                     empty_entity.add(info["id"])
                     continue
                 id2label[info["id"]] = info["label"]
                 relations.extend([tuple(sorted(l)) for l in info["linking"]])
 
-            # x1, y1, x2, y2
-            bbox = info["bbox"]
-            label = info["label"]
-            bbox[0] = int(bbox[0] * 1000.0 / width)
-            bbox[2] = int(bbox[2] * 1000.0 / width)
-            bbox[1] = int(bbox[1] * 1000.0 / height)
-            bbox[3] = int(bbox[3] * 1000.0 / height)
+            # smooth_box
+            bbox = self._smooth_box(info["bbox"], height, width)
 
             text = info["text"]
             encode_res = self.tokenizer.encode(
                 text, pad_to_max_seq_len=False, return_attention_mask=True)
-            gt_label = []
             if not self.add_special_ids:
                 # TODO: use tok.all_special_ids to remove
                 encode_res["input_ids"] = encode_res["input_ids"][1:-1]
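Note on the hunk above: the inlined coordinate scaling becomes a _smooth_box call. The scaling maps pixel coordinates onto the 0-1000 grid expected by LayoutLM-family models. A minimal standalone sketch of the same arithmetic (the function name and image size below are made up for illustration):

    def smooth_box(bbox, height, width):
        # rescale x by 1000/width and y by 1000/height
        bbox[0] = int(bbox[0] * 1000.0 / width)   # x1
        bbox[2] = int(bbox[2] * 1000.0 / width)   # x2
        bbox[1] = int(bbox[1] * 1000.0 / height)  # y1
        bbox[3] = int(bbox[3] * 1000.0 / height)  # y2
        return bbox

    # on a 2000x1000 image, x shrinks by half and y keeps its scale
    assert smooth_box([100, 50, 300, 150], 1000, 2000) == [50, 50, 150, 150]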
@@ -876,35 +866,44 @@ class VQATokenLabelEncode(object):
                                                                             -1]
                 encode_res["attention_mask"] = encode_res["attention_mask"][1:
                                                                             -1]
-            if label.lower() == "other":
-                gt_label.extend([0] * len(encode_res["input_ids"]))
-            else:
-                gt_label.append(self.label2id_map[("b-" + label).upper()])
-                gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
-                                (len(encode_res["input_ids"]) - 1))
+            # parse label
+            if not self.infer_mode:
+                label = info['label']
+                gt_label = self._parse_label(label, encode_res)
 
             # construct entities for re
             if self.contains_re:
-                if gt_label[0] != self.label2id_map["O"]:
-                    entity_id_to_index_map[info["id"]] = len(entities)
-                    label = label.upper()
-                    entities.append({
-                        "start": len(input_ids_list),
-                        "end": len(input_ids_list) + len(encode_res["input_ids"]),
-                        "label": label.upper(),
-                    })
+                if not self.infer_mode:
+                    if gt_label[0] != self.label2id_map["O"]:
+                        entity_id_to_index_map[info["id"]] = len(entities)
+                        label = label.upper()
+                        entities.append({
+                            "start": len(input_ids_list),
+                            "end":
+                            len(input_ids_list) + len(encode_res["input_ids"]),
+                            "label": label.upper(),
+                        })
+                else:
+                    entities.append({
+                        "start": len(input_ids_list),
+                        "end":
+                        len(input_ids_list) + len(encode_res["input_ids"]),
+                        "label": 'O',
+                    })
             input_ids_list.extend(encode_res["input_ids"])
             token_type_ids_list.extend(encode_res["token_type_ids"])
             bbox_list.extend([bbox] * len(encode_res["input_ids"]))
-            gt_label_list.extend(gt_label)
             words_list.append(text)
             segment_offset_id.append(len(input_ids_list))
+            if not self.infer_mode:
+                gt_label_list.extend(gt_label)
 
-        encoded_inputs = {
-            "input_ids": input_ids_list,
-            "labels": gt_label_list,
-            "token_type_ids": token_type_ids_list,
-            "bbox": bbox_list,
-            "attention_mask": [1] * len(input_ids_list),
-        }
-        data.update(encoded_inputs)
+        data['input_ids'] = input_ids_list
+        data['token_type_ids'] = token_type_ids_list
+        data['bbox'] = bbox_list
+        data['attention_mask'] = [1] * len(input_ids_list)
+        data['labels'] = gt_label_list
+        data['segment_offset_id'] = segment_offset_id
+        data['tokenizer_params'] = dict(
+            padding_side=self.tokenizer.padding_side,
+            pad_token_type_id=self.tokenizer.pad_token_type_id,
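Note on the hunk above: the inline BIO label construction moves into _parse_label (added in the last hunk). Segments labeled "other" get class 0 for every token; any other segment gets its B-* id on the first token and the I-* id on the remaining tokens. A minimal sketch with a hypothetical label2id_map; the real map is built from the dataset's class list:

    label2id_map = {"O": 0, "B-QUESTION": 1, "I-QUESTION": 2}

    def parse_label(label, input_ids):
        if label.lower() == "other":
            return [0] * len(input_ids)  # background: every token is O
        gt = [label2id_map[("b-" + label).upper()]]  # first token: B-*
        gt += [label2id_map[("i-" + label).upper()]] * (len(input_ids) - 1)
        return gt

    assert parse_label("question", [5, 6, 7]) == [1, 2, 2]
    assert parse_label("other", [5, 6, 7]) == [0, 0, 0]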
@@ -912,79 +911,45 @@ class VQATokenLabelEncode(object):
 
         if self.contains_re:
             data['entities'] = entities
-            data['relations'] = relations
-            data['id2label'] = id2label
-            data['empty_entity'] = empty_entity
-            data['entity_id_to_index_map'] = entity_id_to_index_map
+            if self.infer_mode:
+                data['ocr_info'] = ocr_info
+            else:
+                data['relations'] = relations
+                data['id2label'] = id2label
+                data['empty_entity'] = empty_entity
+                data['entity_id_to_index_map'] = entity_id_to_index_map
         return data
 
-    def _infer(self, data):
-        def trans_poly_to_bbox(poly):
-            x1 = np.min([p[0] for p in poly])
-            x2 = np.max([p[0] for p in poly])
-            y1 = np.min([p[1] for p in poly])
-            y2 = np.max([p[1] for p in poly])
-            return [x1, y1, x2, y2]
+    def _load_ocr_info(self, data):
+        if self.infer_mode:
+            ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
+            ocr_info = []
+            for res in ocr_result:
+                ocr_info.append({
+                    "text": res[1][0],
+                    "bbox": trans_poly_to_bbox(res[0]),
+                    "poly": res[0],
+                })
+            return ocr_info
+        else:
+            info = data['label']
+            # read text info
+            info_dict = json.loads(info)
+            return info_dict["ocr_info"]
 
-        height, width, _ = data['image'].shape
-        ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
-        ocr_info = []
-        for res in ocr_result:
-            ocr_info.append({
-                "text": res[1][0],
-                "bbox": trans_poly_to_bbox(res[0]),
-                "poly": res[0],
-            })
+    def _smooth_box(self, bbox, height, width):
+        bbox[0] = int(bbox[0] * 1000.0 / width)
+        bbox[2] = int(bbox[2] * 1000.0 / width)
+        bbox[1] = int(bbox[1] * 1000.0 / height)
+        bbox[3] = int(bbox[3] * 1000.0 / height)
+        return bbox
 
-        segment_offset_id = []
-        words_list = []
-        bbox_list = []
-        input_ids_list = []
-        token_type_ids_list = []
-        entities = []
-
-        for info in ocr_info:
-            # x1, y1, x2, y2
-            bbox = copy.deepcopy(info["bbox"])
-            bbox[0] = int(bbox[0] * 1000.0 / width)
-            bbox[2] = int(bbox[2] * 1000.0 / width)
-            bbox[1] = int(bbox[1] * 1000.0 / height)
-            bbox[3] = int(bbox[3] * 1000.0 / height)
-
-            text = info["text"]
-            encode_res = self.tokenizer.encode(
-                text, pad_to_max_seq_len=False, return_attention_mask=True)
-
-            if not self.add_special_ids:
-                # TODO: use tok.all_special_ids to remove
-                encode_res["input_ids"] = encode_res["input_ids"][1:-1]
-                encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
-                                                                            -1]
-                encode_res["attention_mask"] = encode_res["attention_mask"][1:
-                                                                            -1]
-
-            # for re
-            entities.append({
-                "start": len(input_ids_list),
-                "end": len(input_ids_list) + len(encode_res["input_ids"]),
-                "label": "O",
-            })
-
-            input_ids_list.extend(encode_res["input_ids"])
-            token_type_ids_list.extend(encode_res["token_type_ids"])
-            bbox_list.extend([bbox] * len(encode_res["input_ids"]))
-            words_list.append(text)
-            segment_offset_id.append(len(input_ids_list))
-
-        encoded_inputs = {
-            "input_ids": input_ids_list,
-            "token_type_ids": token_type_ids_list,
-            "bbox": bbox_list,
-            "attention_mask": [1] * len(input_ids_list),
-            "entities": entities,
-            'labels': None,
-            'segment_offset_id': segment_offset_id,
-            'ocr_info': ocr_info
-        }
-        data.update(encoded_inputs)
-        return data
+    def _parse_label(self, label, encode_res):
+        gt_label = []
+        if label.lower() == "other":
+            gt_label.extend([0] * len(encode_res["input_ids"]))
+        else:
+            gt_label.append(self.label2id_map[("b-" + label).upper()])
+            gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
+                            (len(encode_res["input_ids"]) - 1))
+        return gt_label
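One loose end worth flagging: the new _load_ocr_info still calls trans_poly_to_bbox, while the nested definition is deleted together with _infer, so the function presumably moves to module scope in a part of this commit not shown in these hunks. Its behavior, per the removed lines, is the axis-aligned bounding box of an OCR polygon:

    import numpy as np

    def trans_poly_to_bbox(poly):
        # poly is a list of [x, y] points from the OCR engine
        x1 = np.min([p[0] for p in poly])
        x2 = np.max([p[0] for p in poly])
        y1 = np.min([p[1] for p in poly])
        y2 = np.max([p[1] for p in poly])
        return [x1, y1, x2, y2]

    # a skewed quadrilateral and its enclosing box
    assert trans_poly_to_bbox([[10, 20], [110, 22], [108, 60], [9, 58]]) == [9, 20, 110, 60]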