diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 4cbd79005..27b4aca21 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -821,54 +821,44 @@ class VQATokenLabelEncode(object):
         self.ocr_engine = ocr_engine
 
     def __call__(self, data):
-        if self.infer_mode == False:
-            return self._train(data)
-        else:
-            return self._infer(data)
+        # load bbox and label info
+        ocr_info = self._load_ocr_info(data)
 
-    def _train(self, data):
-        info = data['label']
-
-        # read text info
-        info_dict = json.loads(info)
-        height = info_dict["height"]
-        width = info_dict["width"]
+        height, width, _ = data['image'].shape
 
         words_list = []
         bbox_list = []
         input_ids_list = []
         token_type_ids_list = []
+        segment_offset_id = []
         gt_label_list = []
 
         if self.contains_re:
             # for re
             entities = []
-            relations = []
-            id2label = {}
-            entity_id_to_index_map = {}
-            empty_entity = set()
-        for info in info_dict["ocr_info"]:
-            if self.contains_re:
+            if not self.infer_mode:
+                relations = []
+                id2label = {}
+                entity_id_to_index_map = {}
+                empty_entity = set()
+
+        data['ocr_info'] = copy.deepcopy(ocr_info)
+
+        for info in ocr_info:
+            if self.contains_re and not self.infer_mode:
                 # for re
                 if len(info["text"]) == 0:
                     empty_entity.add(info["id"])
                     continue
                 id2label[info["id"]] = info["label"]
                 relations.extend([tuple(sorted(l)) for l in info["linking"]])
-
-            # x1, y1, x2, y2
-            bbox = info["bbox"]
-            label = info["label"]
-            bbox[0] = int(bbox[0] * 1000.0 / width)
-            bbox[2] = int(bbox[2] * 1000.0 / width)
-            bbox[1] = int(bbox[1] * 1000.0 / height)
-            bbox[3] = int(bbox[3] * 1000.0 / height)
+            # smooth_box
+            bbox = self._smooth_box(info["bbox"], height, width)
 
             text = info["text"]
             encode_res = self.tokenizer.encode(
                 text, pad_to_max_seq_len=False, return_attention_mask=True)
 
-            gt_label = []
             if not self.add_special_ids:
                 # TODO: use tok.all_special_ids to remove
                 encode_res["input_ids"] = encode_res["input_ids"][1:-1]
@@ -876,35 +866,44 @@ class VQATokenLabelEncode(object):
                                                                             -1]
                 encode_res["attention_mask"] = encode_res["attention_mask"][1:
                                                                             -1]
-            if label.lower() == "other":
-                gt_label.extend([0] * len(encode_res["input_ids"]))
-            else:
-                gt_label.append(self.label2id_map[("b-" + label).upper()])
-                gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
-                                (len(encode_res["input_ids"]) - 1))
+            # parse label
+            if not self.infer_mode:
+                label = info['label']
+                gt_label = self._parse_label(label, encode_res)
+
+            # construct entities for re
             if self.contains_re:
-                if gt_label[0] != self.label2id_map["O"]:
-                    entity_id_to_index_map[info["id"]] = len(entities)
-                entities.append({
-                    "start": len(input_ids_list),
-                    "end": len(input_ids_list) + len(encode_res["input_ids"]),
-                    "label": label.upper(),
-                })
+                if not self.infer_mode:
+                    if gt_label[0] != self.label2id_map["O"]:
+                        entity_id_to_index_map[info["id"]] = len(entities)
+                        entities.append({
+                            "start": len(input_ids_list),
+                            "end":
+                            len(input_ids_list) + len(encode_res["input_ids"]),
+                            "label": label.upper(),
+                        })
+                else:
+                    entities.append({
+                        "start": len(input_ids_list),
+                        "end":
+                        len(input_ids_list) + len(encode_res["input_ids"]),
+                        "label": 'O',
+                    })
 
             input_ids_list.extend(encode_res["input_ids"])
             token_type_ids_list.extend(encode_res["token_type_ids"])
             bbox_list.extend([bbox] * len(encode_res["input_ids"]))
-            gt_label_list.extend(gt_label)
             words_list.append(text)
+            segment_offset_id.append(len(input_ids_list))
+            if not self.infer_mode:
+                gt_label_list.extend(gt_label)
 
-        encoded_inputs = {
-            "input_ids": input_ids_list,
-            "labels": gt_label_list,
-            "token_type_ids": token_type_ids_list,
-            "bbox": bbox_list,
-            "attention_mask": [1] * len(input_ids_list),
-        }
-        data.update(encoded_inputs)
+        data['input_ids'] = input_ids_list
+        data['token_type_ids'] = token_type_ids_list
+        data['bbox'] = bbox_list
+        data['attention_mask'] = [1] * len(input_ids_list)
+        data['labels'] = gt_label_list
+        data['segment_offset_id'] = segment_offset_id
         data['tokenizer_params'] = dict(
             padding_side=self.tokenizer.padding_side,
             pad_token_type_id=self.tokenizer.pad_token_type_id,
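Reviewer note on the two hunks above: the patch folds `_train` and `_infer` into a single `__call__` that always walks `ocr_info`, tokenizes each segment, repeats the segment's 0-1000-normalized box once per sub-token, and appends each segment's running end offset to `segment_offset_id`; only label parsing and the RE bookkeeping stay gated on `infer_mode`. A standalone sketch of the normalization and offset bookkeeping (image size, texts, and token ids are invented for illustration; `smooth_box` mirrors `_smooth_box` and is not part of the patch):

```python
# Sketch only: mimics the bbox/offset bookkeeping in __call__ above.

def smooth_box(bbox, height, width):
    # Same arithmetic as _smooth_box: scale pixel coords onto the
    # 0-1000 grid expected by LayoutLM-style layout embeddings.
    return [
        int(bbox[0] * 1000.0 / width),
        int(bbox[1] * 1000.0 / height),
        int(bbox[2] * 1000.0 / width),
        int(bbox[3] * 1000.0 / height),
    ]

# Two hypothetical OCR segments on an 800x600 image; "ids" stand in
# for the tokenizer's sub-token ids.
segments = [
    {"text": "Invoice", "bbox": [40, 30, 200, 60], "ids": [1037, 2228]},
    {"text": "Total", "bbox": [40, 500, 120, 530], "ids": [3456]},
]

input_ids, bbox_list, segment_offset_id = [], [], []
for seg in segments:
    box = smooth_box(seg["bbox"], height=600, width=800)
    input_ids.extend(seg["ids"])
    bbox_list.extend([box] * len(seg["ids"]))  # one box per sub-token
    segment_offset_id.append(len(input_ids))   # end offset of the segment

print(bbox_list[0])       # [50, 50, 250, 100]
print(segment_offset_id)  # [2, 3]
```

`segment_offset_id` is what lets downstream SER/RE post-processing map the flat token stream back to individual OCR segments.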
@@ -912,79 +911,45 @@ class VQATokenLabelEncode(object):
 
         if self.contains_re:
             data['entities'] = entities
-            data['relations'] = relations
-            data['id2label'] = id2label
-            data['empty_entity'] = empty_entity
-            data['entity_id_to_index_map'] = entity_id_to_index_map
+            if self.infer_mode:
+                data['ocr_info'] = ocr_info
+            else:
+                data['relations'] = relations
+                data['id2label'] = id2label
+                data['empty_entity'] = empty_entity
+                data['entity_id_to_index_map'] = entity_id_to_index_map
         return data
 
-    def _infer(self, data):
-        def trans_poly_to_bbox(poly):
-            x1 = np.min([p[0] for p in poly])
-            x2 = np.max([p[0] for p in poly])
-            y1 = np.min([p[1] for p in poly])
-            y2 = np.max([p[1] for p in poly])
-            return [x1, y1, x2, y2]
+    def _load_ocr_info(self, data):
+        if self.infer_mode:
+            ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
+            ocr_info = []
+            for res in ocr_result:
+                ocr_info.append({
+                    "text": res[1][0],
+                    "bbox": trans_poly_to_bbox(res[0]),
+                    "poly": res[0],
+                })
+            return ocr_info
+        else:
+            info = data['label']
+            # read text info
+            info_dict = json.loads(info)
+            return info_dict["ocr_info"]
 
-        height, width, _ = data['image'].shape
-        ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
-        ocr_info = []
-        for res in ocr_result:
-            ocr_info.append({
-                "text": res[1][0],
-                "bbox": trans_poly_to_bbox(res[0]),
-                "poly": res[0],
-            })
+    def _smooth_box(self, bbox, height, width):
+        bbox[0] = int(bbox[0] * 1000.0 / width)
+        bbox[2] = int(bbox[2] * 1000.0 / width)
+        bbox[1] = int(bbox[1] * 1000.0 / height)
+        bbox[3] = int(bbox[3] * 1000.0 / height)
+        return bbox
 
-        segment_offset_id = []
-        words_list = []
-        bbox_list = []
-        input_ids_list = []
-        token_type_ids_list = []
-        entities = []
-
-        for info in ocr_info:
-            # x1, y1, x2, y2
-            bbox = copy.deepcopy(info["bbox"])
-            bbox[0] = int(bbox[0] * 1000.0 / width)
-            bbox[2] = int(bbox[2] * 1000.0 / width)
-            bbox[1] = int(bbox[1] * 1000.0 / height)
-            bbox[3] = int(bbox[3] * 1000.0 / height)
-
-            text = info["text"]
-            encode_res = self.tokenizer.encode(
-                text, pad_to_max_seq_len=False, return_attention_mask=True)
-
-            if not self.add_special_ids:
-                # TODO: use tok.all_special_ids to remove
-                encode_res["input_ids"] = encode_res["input_ids"][1:-1]
-                encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
-                                                                            -1]
-                encode_res["attention_mask"] = encode_res["attention_mask"][1:
-                                                                            -1]
-
-            # for re
-            entities.append({
-                "start": len(input_ids_list),
-                "end": len(input_ids_list) + len(encode_res["input_ids"]),
-                "label": "O",
-            })
-
-            input_ids_list.extend(encode_res["input_ids"])
-            token_type_ids_list.extend(encode_res["token_type_ids"])
-            bbox_list.extend([bbox] * len(encode_res["input_ids"]))
-            words_list.append(text)
-            segment_offset_id.append(len(input_ids_list))
-
-        encoded_inputs = {
-            "input_ids": input_ids_list,
-            "token_type_ids": token_type_ids_list,
-            "bbox": bbox_list,
-            "attention_mask": [1] * len(input_ids_list),
-            "entities": entities,
-            'labels': None,
-            'segment_offset_id': segment_offset_id,
-            'ocr_info': ocr_info
-        }
-        data.update(encoded_inputs)
-        return data
+    def _parse_label(self, label, encode_res):
+        gt_label = []
+        if label.lower() == "other":
+            gt_label.extend([0] * len(encode_res["input_ids"]))
+        else:
+            gt_label.append(self.label2id_map[("b-" + label).upper()])
+            gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
+                            (len(encode_res["input_ids"]) - 1))
+        return gt_label
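`_parse_label` is per-sub-token BIO tagging: the first sub-token of a labeled segment gets the `B-` id, the remaining sub-tokens get `I-`, and segments labeled "other" map to 0. A self-contained illustration; the `label2id_map` below is hypothetical (the real one is built from the dataset's label list):

```python
# Hypothetical label map; the real ids come from the dataset dict.
label2id_map = {
    "O": 0,
    "B-QUESTION": 1,
    "I-QUESTION": 2,
    "B-ANSWER": 3,
    "I-ANSWER": 4,
}

def parse_label(label, input_ids):
    # Mirrors _parse_label: "other" -> all O (0); anything else ->
    # one B- id followed by I- ids for the remaining sub-tokens.
    if label.lower() == "other":
        return [0] * len(input_ids)
    return ([label2id_map[("b-" + label).upper()]] +
            [label2id_map[("i-" + label).upper()]] * (len(input_ids) - 1))

print(parse_label("question", [101, 102, 103]))  # [1, 2, 2]
print(parse_label("other", [101, 102]))          # [0, 0]
```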
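One thing to flag for review: `_load_ocr_info` still calls `trans_poly_to_bbox`, but this patch removes the only definition visible here (the closure inside the deleted `_infer`). The call resolves only if the helper is hoisted to module scope elsewhere in the change; reconstructed from the deleted lines, that module-level helper would look like:

```python
import numpy as np

def trans_poly_to_bbox(poly):
    # Collapse a 4-point OCR polygon [[x, y], ...] into an
    # axis-aligned [x1, y1, x2, y2] box.
    x1 = np.min([p[0] for p in poly])
    x2 = np.max([p[0] for p in poly])
    y1 = np.min([p[1] for p in poly])
    y2 = np.max([p[1] for p in poly])
    return [x1, y1, x2, y2]
```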