diff --git a/configs/vqa/re/layoutxlm.yml b/configs/vqa/re/layoutxlm.yml new file mode 100644 index 0000000000..ca6b0d29db --- /dev/null +++ b/configs/vqa/re/layoutxlm.yml @@ -0,0 +1,122 @@ +Global: + use_gpu: True + epoch_num: &epoch_num 200 + log_smooth_window: 10 + print_batch_step: 10 + save_model_dir: ./output/re_layoutxlm/ + save_epoch_step: 2000 + # evaluation is run every 10 iterations after the 0th iteration + eval_batch_step: [ 0, 19 ] + cal_metric_during_train: False + save_inference_dir: + use_visualdl: False + seed: 2022 + infer_img: doc/vqa/input/zh_val_21.jpg + save_res_path: ./output/re/ + +Architecture: + model_type: vqa + algorithm: &algorithm "LayoutXLM" + Transform: + Backbone: + name: LayoutXLMForRe + pretrained: True + checkpoints: + +Loss: + name: LossFromOutput + key: loss + reduction: mean + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + clip_norm: 10 + lr: + learning_rate: 0.00005 + regularizer: + name: L2 + factor: 0.00000 + +PostProcess: + name: VQAReTokenLayoutLMPostProcess + +Metric: + name: VQAReTokenMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: train_data/XFUND/zh_train/image + label_file_list: + - train_data/XFUND/zh_train/xfun_normalize_train.json + ratio_list: [ 1.0 ] + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VQATokenLabelEncode: # Class handling label + contains_re: True + algorithm: *algorithm + class_path: &class_path ppstructure/vqa/labels/labels_ser.txt + - VQATokenPad: + max_seq_len: &max_seq_len 512 + return_attention_mask: True + - VQAReTokenRelation: + - VQAReTokenChunk: + max_seq_len: *max_seq_len + - Resize: + size: [224,224] + - NormalizeImage: + scale: 1 + mean: [ 123.675, 116.28, 103.53 ] + std: [ 58.395, 57.12, 57.375 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + collate_fn: ListCollator + +Eval: + dataset: + name: SimpleDataSet + data_dir: train_data/XFUND/zh_val/image + label_file_list: + - train_data/XFUND/zh_val/xfun_normalize_val.json + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VQATokenLabelEncode: # Class handling label + contains_re: True + algorithm: *algorithm + class_path: *class_path + - VQATokenPad: + max_seq_len: *max_seq_len + return_attention_mask: True + - VQAReTokenRelation: + - VQAReTokenChunk: + max_seq_len: *max_seq_len + - Resize: + size: [224,224] + - NormalizeImage: + scale: 1 + mean: [ 123.675, 116.28, 103.53 ] + std: [ 58.395, 57.12, 57.375 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + collate_fn: ListCollator diff --git a/configs/vqa/ser/layoutlm.yml b/configs/vqa/ser/layoutlm.yml new file mode 100644 index 0000000000..87131170c9 --- /dev/null +++ b/configs/vqa/ser/layoutlm.yml @@ -0,0 +1,120 @@ +Global: + use_gpu: True + epoch_num: &epoch_num 200 + log_smooth_window: 10 + print_batch_step: 10 + save_model_dir: ./output/ser_layoutlm/ + save_epoch_step: 2000 + # evaluation is run every 10 iterations after the 0th iteration + eval_batch_step: [ 0, 19 ] + cal_metric_during_train: False + save_inference_dir: + 
use_visualdl: False + seed: 2022 + infer_img: doc/vqa/input/zh_val_0.jpg + save_res_path: ./output/ser/ + +Architecture: + model_type: vqa + algorithm: &algorithm "LayoutLM" + Transform: + Backbone: + name: LayoutLMForSer + pretrained: True + checkpoints: + num_classes: &num_classes 7 + +Loss: + name: VQASerTokenLayoutLMLoss + num_classes: *num_classes + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + lr: + name: Linear + learning_rate: 0.00005 + epochs: *epoch_num + warmup_epoch: 2 + regularizer: + name: L2 + factor: 0.00000 + +PostProcess: + name: VQASerTokenLayoutLMPostProcess + class_path: &class_path ppstructure/vqa/labels/labels_ser.txt + +Metric: + name: VQASerTokenMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: train_data/XFUND/zh_train/image + label_file_list: + - train_data/XFUND/zh_train/xfun_normalize_train.json + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VQATokenLabelEncode: # Class handling label + contains_re: False + algorithm: *algorithm + class_path: *class_path + - VQATokenPad: + max_seq_len: &max_seq_len 512 + return_attention_mask: True + - VQASerTokenChunk: + max_seq_len: *max_seq_len + - Resize: + size: [224,224] + - NormalizeImage: + scale: 1 + mean: [ 123.675, 116.28, 103.53 ] + std: [ 58.395, 57.12, 57.375 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: train_data/XFUND/zh_val/image + label_file_list: + - train_data/XFUND/zh_val/xfun_normalize_val.json + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VQATokenLabelEncode: # Class handling label + contains_re: False + algorithm: *algorithm + class_path: *class_path + - VQATokenPad: + max_seq_len: *max_seq_len + return_attention_mask: True + - VQASerTokenChunk: + max_seq_len: *max_seq_len + - Resize: + size: [224,224] + - NormalizeImage: + scale: 1 + mean: [ 123.675, 116.28, 103.53 ] + std: [ 58.395, 57.12, 57.375 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 8 + num_workers: 4 diff --git a/configs/vqa/ser/layoutxlm.yml b/configs/vqa/ser/layoutxlm.yml new file mode 100644 index 0000000000..eb1cca5a21 --- /dev/null +++ b/configs/vqa/ser/layoutxlm.yml @@ -0,0 +1,121 @@ +Global: + use_gpu: True + epoch_num: &epoch_num 200 + log_smooth_window: 10 + print_batch_step: 10 + save_model_dir: ./output/ser_layoutxlm/ + save_epoch_step: 2000 + # evaluation is run every 10 iterations after the 0th iteration + eval_batch_step: [ 0, 19 ] + cal_metric_during_train: False + save_inference_dir: + use_visualdl: False + seed: 2022 + infer_img: doc/vqa/input/zh_val_42.jpg + save_res_path: ./output/ser + +Architecture: + model_type: vqa + algorithm: &algorithm "LayoutXLM" + Transform: + Backbone: + name: LayoutXLMForSer + pretrained: True + checkpoints: + num_classes: &num_classes 7 + +Loss: + name: VQASerTokenLayoutLMLoss + num_classes: *num_classes + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + lr: + name: Linear + learning_rate: 0.00005 + epochs: *epoch_num + warmup_epoch: 2 + regularizer: + name: L2 + factor: 0.00000 + 
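The Optimizer blocks in these configs pair AdamW with a `Linear` learning-rate schedule (2 warmup epochs, then decay over the remaining `epoch_num`). A rough, self-contained sketch of the shape such a schedule typically takes; the per-epoch step count below is only a guess based on `eval_batch_step: [0, 19]`, not a value read from the code:

```python
# Illustrative sketch (not the ppocr implementation) of the schedule that
# "lr: {name: Linear, learning_rate: 5e-5, epochs: 200, warmup_epoch: 2}" describes:
# linear warmup from 0 to the base LR, then linear decay back to 0.

def linear_warmup_decay(base_lr, step, steps_per_epoch, epochs, warmup_epoch):
    warmup_steps = warmup_epoch * steps_per_epoch
    total_steps = epochs * steps_per_epoch
    if step < warmup_steps:
        return base_lr * step / max(warmup_steps, 1)
    # linear decay over the remaining steps
    remaining = max(total_steps - warmup_steps, 1)
    return base_lr * max(total_steps - step, 0) / remaining

if __name__ == "__main__":
    # assumed ~19 iterations per epoch, mirroring eval_batch_step in the configs
    for s in (0, 19, 38, 1900, 3800):
        print(s, linear_warmup_decay(5e-5, s, steps_per_epoch=19, epochs=200, warmup_epoch=2))
```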
+PostProcess: + name: VQASerTokenLayoutLMPostProcess + class_path: &class_path ppstructure/vqa/labels/labels_ser.txt + +Metric: + name: VQASerTokenMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: train_data/XFUND/zh_train/image + label_file_list: + - train_data/XFUND/zh_train/xfun_normalize_train.json + ratio_list: [ 1.0 ] + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VQATokenLabelEncode: # Class handling label + contains_re: False + algorithm: *algorithm + class_path: *class_path + - VQATokenPad: + max_seq_len: &max_seq_len 512 + return_attention_mask: True + - VQASerTokenChunk: + max_seq_len: *max_seq_len + - Resize: + size: [224,224] + - NormalizeImage: + scale: 1 + mean: [ 123.675, 116.28, 103.53 ] + std: [ 58.395, 57.12, 57.375 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: train_data/XFUND/zh_val/image + label_file_list: + - train_data/XFUND/zh_val/xfun_normalize_val.json + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VQATokenLabelEncode: # Class handling label + contains_re: False + algorithm: *algorithm + class_path: *class_path + - VQATokenPad: + max_seq_len: *max_seq_len + return_attention_mask: True + - VQASerTokenChunk: + max_seq_len: *max_seq_len + - Resize: + size: [224,224] + - NormalizeImage: + scale: 1 + mean: [ 123.675, 116.28, 103.53 ] + std: [ 58.395, 57.12, 57.375 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 8 + num_workers: 4 diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py index 0cb86108d2..34cf80f5e5 100755 --- a/deploy/slim/quantization/export_model.py +++ b/deploy/slim/quantization/export_model.py @@ -76,7 +76,7 @@ def main(): } FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) - merge_config(FLAGS.opt) + config = merge_config(config, FLAGS.opt) logger = get_logger() # build post process diff --git a/ppstructure/vqa/images/input/zh_val_0.jpg b/doc/vqa/input/zh_val_0.jpg similarity index 100% rename from ppstructure/vqa/images/input/zh_val_0.jpg rename to doc/vqa/input/zh_val_0.jpg diff --git a/ppstructure/vqa/images/input/zh_val_21.jpg b/doc/vqa/input/zh_val_21.jpg similarity index 100% rename from ppstructure/vqa/images/input/zh_val_21.jpg rename to doc/vqa/input/zh_val_21.jpg diff --git a/ppstructure/vqa/images/input/zh_val_40.jpg b/doc/vqa/input/zh_val_40.jpg similarity index 100% rename from ppstructure/vqa/images/input/zh_val_40.jpg rename to doc/vqa/input/zh_val_40.jpg diff --git a/ppstructure/vqa/images/input/zh_val_42.jpg b/doc/vqa/input/zh_val_42.jpg similarity index 100% rename from ppstructure/vqa/images/input/zh_val_42.jpg rename to doc/vqa/input/zh_val_42.jpg diff --git a/ppstructure/vqa/images/result_re/zh_val_21_re.jpg b/doc/vqa/result_re/zh_val_21_re.jpg similarity index 100% rename from ppstructure/vqa/images/result_re/zh_val_21_re.jpg rename to doc/vqa/result_re/zh_val_21_re.jpg diff --git a/ppstructure/vqa/images/result_re/zh_val_40_re.jpg b/doc/vqa/result_re/zh_val_40_re.jpg similarity index 
100% rename from ppstructure/vqa/images/result_re/zh_val_40_re.jpg rename to doc/vqa/result_re/zh_val_40_re.jpg diff --git a/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg b/doc/vqa/result_ser/zh_val_0_ser.jpg similarity index 100% rename from ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg rename to doc/vqa/result_ser/zh_val_0_ser.jpg diff --git a/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg b/doc/vqa/result_ser/zh_val_42_ser.jpg similarity index 100% rename from ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg rename to doc/vqa/result_ser/zh_val_42_ser.jpg diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index 429d5a528b..60ab7bd0b4 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -87,13 +87,19 @@ def build_dataloader(config, mode, device, logger, seed=None): shuffle=shuffle, drop_last=drop_last) + if 'collate_fn' in loader_config: + from . import collate_fn + collate_fn = getattr(collate_fn, loader_config['collate_fn'])() + else: + collate_fn = None data_loader = DataLoader( dataset=dataset, batch_sampler=batch_sampler, places=device, num_workers=num_workers, return_list=True, - use_shared_memory=use_shared_memory) + use_shared_memory=use_shared_memory, + collate_fn=collate_fn) # support exit using ctrl+c signal.signal(signal.SIGINT, term_mp) diff --git a/ppstructure/vqa/data_collator.py b/ppocr/data/collate_fn.py similarity index 59% rename from ppstructure/vqa/data_collator.py rename to ppocr/data/collate_fn.py index a969935b48..89c6b4fd5a 100644 --- a/ppstructure/vqa/data_collator.py +++ b/ppocr/data/collate_fn.py @@ -15,20 +15,20 @@ import paddle import numbers import numpy as np +from collections import defaultdict -class DataCollator: +class DictCollator(object): """ data batch """ def __call__(self, batch): - data_dict = {} + # todo:support batch operators + data_dict = defaultdict(list) to_tensor_keys = [] for sample in batch: for k, v in sample.items(): - if k not in data_dict: - data_dict[k] = [] if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): if k not in to_tensor_keys: to_tensor_keys.append(k) @@ -36,3 +36,23 @@ class DataCollator: for k in to_tensor_keys: data_dict[k] = paddle.to_tensor(data_dict[k]) return data_dict + + +class ListCollator(object): + """ + data batch + """ + + def __call__(self, batch): + # todo:support batch operators + data_dict = defaultdict(list) + to_tensor_idxs = [] + for sample in batch: + for idx, v in enumerate(sample): + if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): + if idx not in to_tensor_idxs: + to_tensor_idxs.append(idx) + data_dict[idx].append(v) + for idx in to_tensor_idxs: + data_dict[idx] = paddle.to_tensor(data_dict[idx]) + return list(data_dict.values()) diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 5aaa1cd71e..90a70875b9 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -34,6 +34,8 @@ from .sast_process import * from .pg_process import * from .gen_table_mask import * +from .vqa import * + def transform(data, ops=None): """ transform """ diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index f83255b732..786647f1f6 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -17,6 +17,7 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals +import copy import numpy as np import string from shapely.geometry import LineString, Point, Polygon @@ -736,7 +737,7 @@ class TableLabelEncode(object): % 
beg_or_end else: assert False, "Unsupport type %s in char_or_elem" \ - % char_or_elem + % char_or_elem return idx @@ -782,3 +783,176 @@ class SARLabelEncode(BaseRecLabelEncode): def get_ignored_tokens(self): return [self.padding_idx] + + +class VQATokenLabelEncode(object): + """ + Label encode for NLP VQA methods + """ + + def __init__(self, + class_path, + contains_re=False, + add_special_ids=False, + algorithm='LayoutXLM', + infer_mode=False, + ocr_engine=None, + **kwargs): + super(VQATokenLabelEncode, self).__init__() + from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer + from ppocr.utils.utility import load_vqa_bio_label_maps + tokenizer_dict = { + 'LayoutXLM': { + 'class': LayoutXLMTokenizer, + 'pretrained_model': 'layoutxlm-base-uncased' + }, + 'LayoutLM': { + 'class': LayoutLMTokenizer, + 'pretrained_model': 'layoutlm-base-uncased' + } + } + self.contains_re = contains_re + tokenizer_config = tokenizer_dict[algorithm] + self.tokenizer = tokenizer_config['class'].from_pretrained( + tokenizer_config['pretrained_model']) + self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path) + self.add_special_ids = add_special_ids + self.infer_mode = infer_mode + self.ocr_engine = ocr_engine + + def __call__(self, data): + # load bbox and label info + ocr_info = self._load_ocr_info(data) + + height, width, _ = data['image'].shape + + words_list = [] + bbox_list = [] + input_ids_list = [] + token_type_ids_list = [] + segment_offset_id = [] + gt_label_list = [] + + entities = [] + + # for re + train_re = self.contains_re and not self.infer_mode + if train_re: + relations = [] + id2label = {} + entity_id_to_index_map = {} + empty_entity = set() + + data['ocr_info'] = copy.deepcopy(ocr_info) + + for info in ocr_info: + if train_re: + # for re + if len(info["text"]) == 0: + empty_entity.add(info["id"]) + continue + id2label[info["id"]] = info["label"] + relations.extend([tuple(sorted(l)) for l in info["linking"]]) + # smooth_box + bbox = self._smooth_box(info["bbox"], height, width) + + text = info["text"] + encode_res = self.tokenizer.encode( + text, pad_to_max_seq_len=False, return_attention_mask=True) + + if not self.add_special_ids: + # TODO: use tok.all_special_ids to remove + encode_res["input_ids"] = encode_res["input_ids"][1:-1] + encode_res["token_type_ids"] = encode_res["token_type_ids"][1: + -1] + encode_res["attention_mask"] = encode_res["attention_mask"][1: + -1] + # parse label + if not self.infer_mode: + label = info['label'] + gt_label = self._parse_label(label, encode_res) + + # construct entities for re + if train_re: + if gt_label[0] != self.label2id_map["O"]: + entity_id_to_index_map[info["id"]] = len(entities) + label = label.upper() + entities.append({ + "start": len(input_ids_list), + "end": + len(input_ids_list) + len(encode_res["input_ids"]), + "label": label.upper(), + }) + else: + entities.append({ + "start": len(input_ids_list), + "end": len(input_ids_list) + len(encode_res["input_ids"]), + "label": 'O', + }) + input_ids_list.extend(encode_res["input_ids"]) + token_type_ids_list.extend(encode_res["token_type_ids"]) + bbox_list.extend([bbox] * len(encode_res["input_ids"])) + words_list.append(text) + segment_offset_id.append(len(input_ids_list)) + if not self.infer_mode: + gt_label_list.extend(gt_label) + + data['input_ids'] = input_ids_list + data['token_type_ids'] = token_type_ids_list + data['bbox'] = bbox_list + data['attention_mask'] = [1] * len(input_ids_list) + data['labels'] = gt_label_list + data['segment_offset_id'] = 
segment_offset_id + data['tokenizer_params'] = dict( + padding_side=self.tokenizer.padding_side, + pad_token_type_id=self.tokenizer.pad_token_type_id, + pad_token_id=self.tokenizer.pad_token_id) + data['entities'] = entities + + if train_re: + data['relations'] = relations + data['id2label'] = id2label + data['empty_entity'] = empty_entity + data['entity_id_to_index_map'] = entity_id_to_index_map + return data + + def _load_ocr_info(self, data): + def trans_poly_to_bbox(poly): + x1 = np.min([p[0] for p in poly]) + x2 = np.max([p[0] for p in poly]) + y1 = np.min([p[1] for p in poly]) + y2 = np.max([p[1] for p in poly]) + return [x1, y1, x2, y2] + + if self.infer_mode: + ocr_result = self.ocr_engine.ocr(data['image'], cls=False) + ocr_info = [] + for res in ocr_result: + ocr_info.append({ + "text": res[1][0], + "bbox": trans_poly_to_bbox(res[0]), + "poly": res[0], + }) + return ocr_info + else: + info = data['label'] + # read text info + info_dict = json.loads(info) + return info_dict["ocr_info"] + + def _smooth_box(self, bbox, height, width): + bbox[0] = int(bbox[0] * 1000.0 / width) + bbox[2] = int(bbox[2] * 1000.0 / width) + bbox[1] = int(bbox[1] * 1000.0 / height) + bbox[3] = int(bbox[3] * 1000.0 / height) + return bbox + + def _parse_label(self, label, encode_res): + gt_label = [] + if label.lower() == "other": + gt_label.extend([0] * len(encode_res["input_ids"])) + else: + gt_label.append(self.label2id_map[("b-" + label).upper()]) + gt_label.extend([self.label2id_map[("i-" + label).upper()]] * + (len(encode_res["input_ids"]) - 1)) + return gt_label diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index daa67a25da..f6568affc8 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -170,17 +170,19 @@ class Resize(object): def __call__(self, data): img = data['image'] - text_polys = data['polys'] + if 'polys' in data: + text_polys = data['polys'] img_resize, [ratio_h, ratio_w] = self.resize_image(img) - new_boxes = [] - for box in text_polys: - new_box = [] - for cord in box: - new_box.append([cord[0] * ratio_w, cord[1] * ratio_h]) - new_boxes.append(new_box) + if 'polys' in data: + new_boxes = [] + for box in text_polys: + new_box = [] + for cord in box: + new_box.append([cord[0] * ratio_w, cord[1] * ratio_h]) + new_boxes.append(new_box) + data['polys'] = np.array(new_boxes, dtype=np.float32) data['image'] = img_resize - data['polys'] = np.array(new_boxes, dtype=np.float32) return data diff --git a/ppocr/data/imaug/vqa/__init__.py b/ppocr/data/imaug/vqa/__init__.py new file mode 100644 index 0000000000..a5025e7985 --- /dev/null +++ b/ppocr/data/imaug/vqa/__init__.py @@ -0,0 +1,19 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
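To make the label-encoding step above concrete, here is a minimal standalone sketch of its two core pieces: scaling an OCR box to the 0-1000 grid that LayoutLM/LayoutXLM expect, and expanding a segment-level SER label into per-token BIO ids. The sample box, image size, and label map are invented for illustration.

```python
# Standalone sketch of the two core pieces of VQATokenLabelEncode shown above.
# The sample bbox, image size and label map are invented for illustration.

def smooth_box(bbox, height, width):
    # scale pixel coordinates to the 0-1000 grid LayoutLM/LayoutXLM expect
    x1, y1, x2, y2 = bbox
    return [int(x1 * 1000 / width), int(y1 * 1000 / height),
            int(x2 * 1000 / width), int(y2 * 1000 / height)]

def parse_label(label, num_tokens, label2id_map):
    # "other" -> all O; otherwise B-<label> for the first token, I-<label> for the rest
    if label.lower() == "other":
        return [0] * num_tokens
    return ([label2id_map[("b-" + label).upper()]] +
            [label2id_map[("i-" + label).upper()]] * (num_tokens - 1))

label2id_map = {"O": 0, "B-QUESTION": 1, "I-QUESTION": 2, "B-ANSWER": 3, "I-ANSWER": 4}
print(smooth_box([120, 40, 360, 80], height=800, width=600))             # [200, 50, 600, 100]
print(parse_label("question", num_tokens=4, label2id_map=label2id_map))  # [1, 2, 2, 2]
```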
+ +from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation + +__all__ = [ + 'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation' +] diff --git a/ppocr/data/imaug/vqa/token/__init__.py b/ppocr/data/imaug/vqa/token/__init__.py new file mode 100644 index 0000000000..7c11566175 --- /dev/null +++ b/ppocr/data/imaug/vqa/token/__init__.py @@ -0,0 +1,17 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk +from .vqa_token_pad import VQATokenPad +from .vqa_token_relation import VQAReTokenRelation diff --git a/ppocr/data/imaug/vqa/token/vqa_token_chunk.py b/ppocr/data/imaug/vqa/token/vqa_token_chunk.py new file mode 100644 index 0000000000..deb55b4d55 --- /dev/null +++ b/ppocr/data/imaug/vqa/token/vqa_token_chunk.py @@ -0,0 +1,117 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +class VQASerTokenChunk(object): + def __init__(self, max_seq_len=512, infer_mode=False, **kwargs): + self.max_seq_len = max_seq_len + self.infer_mode = infer_mode + + def __call__(self, data): + encoded_inputs_all = [] + seq_len = len(data['input_ids']) + for index in range(0, seq_len, self.max_seq_len): + chunk_beg = index + chunk_end = min(index + self.max_seq_len, seq_len) + encoded_inputs_example = {} + for key in data: + if key in [ + 'label', 'input_ids', 'labels', 'token_type_ids', + 'bbox', 'attention_mask' + ]: + if self.infer_mode and key == 'labels': + encoded_inputs_example[key] = data[key] + else: + encoded_inputs_example[key] = data[key][chunk_beg: + chunk_end] + else: + encoded_inputs_example[key] = data[key] + + encoded_inputs_all.append(encoded_inputs_example) + return encoded_inputs_all[0] + + +class VQAReTokenChunk(object): + def __init__(self, + max_seq_len=512, + entities_labels=None, + infer_mode=False, + **kwargs): + self.max_seq_len = max_seq_len + self.entities_labels = { + 'HEADER': 0, + 'QUESTION': 1, + 'ANSWER': 2 + } if entities_labels is None else entities_labels + self.infer_mode = infer_mode + + def __call__(self, data): + # prepare data + entities = data.pop('entities') + relations = data.pop('relations') + encoded_inputs_all = [] + for index in range(0, len(data["input_ids"]), self.max_seq_len): + item = {} + for key in data: + if key in [ + 'label', 'input_ids', 'labels', 'token_type_ids', + 'bbox', 'attention_mask' + ]: + if self.infer_mode and key == 'labels': + item[key] = data[key] + else: + item[key] = data[key][index:index + self.max_seq_len] + else: + item[key] = data[key] + # select entity in current chunk + entities_in_this_span = [] + global_to_local_map = {} # + for entity_id, entity in enumerate(entities): + if (index <= entity["start"] < index + self.max_seq_len and + index <= entity["end"] < index + self.max_seq_len): + entity["start"] = entity["start"] - index + entity["end"] = entity["end"] - index + global_to_local_map[entity_id] = len(entities_in_this_span) + entities_in_this_span.append(entity) + + # select relations in current chunk + relations_in_this_span = [] + for relation in relations: + if (index <= relation["start_index"] < index + self.max_seq_len + and index <= relation["end_index"] < + index + self.max_seq_len): + relations_in_this_span.append({ + "head": global_to_local_map[relation["head"]], + "tail": global_to_local_map[relation["tail"]], + "start_index": relation["start_index"] - index, + "end_index": relation["end_index"] - index, + }) + item.update({ + "entities": self.reformat(entities_in_this_span), + "relations": self.reformat(relations_in_this_span), + }) + item['entities']['label'] = [ + self.entities_labels[x] for x in item['entities']['label'] + ] + encoded_inputs_all.append(item) + return encoded_inputs_all[0] + + def reformat(self, data): + new_data = {} + for item in data: + for k, v in item.items(): + if k not in new_data: + new_data[k] = [] + new_data[k].append(v) + return new_data diff --git a/ppocr/data/imaug/vqa/token/vqa_token_pad.py b/ppocr/data/imaug/vqa/token/vqa_token_pad.py new file mode 100644 index 0000000000..8e5a20f95f --- /dev/null +++ b/ppocr/data/imaug/vqa/token/vqa_token_pad.py @@ -0,0 +1,104 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import numpy as np + + +class VQATokenPad(object): + def __init__(self, + max_seq_len=512, + pad_to_max_seq_len=True, + return_attention_mask=True, + return_token_type_ids=True, + truncation_strategy="longest_first", + return_overflowing_tokens=False, + return_special_tokens_mask=False, + infer_mode=False, + **kwargs): + self.max_seq_len = max_seq_len + self.pad_to_max_seq_len = max_seq_len + self.return_attention_mask = return_attention_mask + self.return_token_type_ids = return_token_type_ids + self.truncation_strategy = truncation_strategy + self.return_overflowing_tokens = return_overflowing_tokens + self.return_special_tokens_mask = return_special_tokens_mask + self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index + self.infer_mode = infer_mode + + def __call__(self, data): + needs_to_be_padded = self.pad_to_max_seq_len and len(data[ + "input_ids"]) < self.max_seq_len + + if needs_to_be_padded: + if 'tokenizer_params' in data: + tokenizer_params = data.pop('tokenizer_params') + else: + tokenizer_params = dict( + padding_side='right', pad_token_type_id=0, pad_token_id=1) + + difference = self.max_seq_len - len(data["input_ids"]) + if tokenizer_params['padding_side'] == 'right': + if self.return_attention_mask: + data["attention_mask"] = [1] * len(data[ + "input_ids"]) + [0] * difference + if self.return_token_type_ids: + data["token_type_ids"] = ( + data["token_type_ids"] + + [tokenizer_params['pad_token_type_id']] * difference) + if self.return_special_tokens_mask: + data["special_tokens_mask"] = data[ + "special_tokens_mask"] + [1] * difference + data["input_ids"] = data["input_ids"] + [ + tokenizer_params['pad_token_id'] + ] * difference + if not self.infer_mode: + data["labels"] = data[ + "labels"] + [self.pad_token_label_id] * difference + data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference + elif tokenizer_params['padding_side'] == 'left': + if self.return_attention_mask: + data["attention_mask"] = [0] * difference + [ + 1 + ] * len(data["input_ids"]) + if self.return_token_type_ids: + data["token_type_ids"] = ( + [tokenizer_params['pad_token_type_id']] * difference + + data["token_type_ids"]) + if self.return_special_tokens_mask: + data["special_tokens_mask"] = [ + 1 + ] * difference + data["special_tokens_mask"] + data["input_ids"] = [tokenizer_params['pad_token_id'] + ] * difference + data["input_ids"] + if not self.infer_mode: + data["labels"] = [self.pad_token_label_id + ] * difference + data["labels"] + data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"] + else: + if self.return_attention_mask: + data["attention_mask"] = [1] * len(data["input_ids"]) + + for key in data: + if key in [ + 'input_ids', 'labels', 'token_type_ids', 'bbox', + 'attention_mask' + ]: + if self.infer_mode: + if key != 'labels': + length = min(len(data[key]), self.max_seq_len) + data[key] = data[key][:length] + else: + continue + data[key] = np.array(data[key], dtype='int64') + return data diff --git a/ppocr/data/imaug/vqa/token/vqa_token_relation.py b/ppocr/data/imaug/vqa/token/vqa_token_relation.py new file mode 100644 index 0000000000..293988ff85 --- 
/dev/null +++ b/ppocr/data/imaug/vqa/token/vqa_token_relation.py @@ -0,0 +1,67 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class VQAReTokenRelation(object): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + """ + build relations + """ + entities = data['entities'] + relations = data['relations'] + id2label = data.pop('id2label') + empty_entity = data.pop('empty_entity') + entity_id_to_index_map = data.pop('entity_id_to_index_map') + + relations = list(set(relations)) + relations = [ + rel for rel in relations + if rel[0] not in empty_entity and rel[1] not in empty_entity + ] + kv_relations = [] + for rel in relations: + pair = [id2label[rel[0]], id2label[rel[1]]] + if pair == ["question", "answer"]: + kv_relations.append({ + "head": entity_id_to_index_map[rel[0]], + "tail": entity_id_to_index_map[rel[1]] + }) + elif pair == ["answer", "question"]: + kv_relations.append({ + "head": entity_id_to_index_map[rel[1]], + "tail": entity_id_to_index_map[rel[0]] + }) + else: + continue + relations = sorted( + [{ + "head": rel["head"], + "tail": rel["tail"], + "start_index": self.get_relation_span(rel, entities)[0], + "end_index": self.get_relation_span(rel, entities)[1], + } for rel in kv_relations], + key=lambda x: x["head"], ) + + data['relations'] = relations + return data + + def get_relation_span(self, rel, entities): + bound = [] + for entity_index in [rel["head"], rel["tail"]]: + bound.append(entities[entity_index]["start"]) + bound.append(entities[entity_index]["end"]) + return min(bound), max(bound) diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index e2d6dc9327..e1b49809d1 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -38,6 +38,9 @@ class LMDBDataSet(Dataset): np.random.shuffle(self.data_idx_order_list) self.ops = create_operators(dataset_config['transforms'], global_config) + ratio_list = dataset_config.get("ratio_list", [1.0]) + self.need_reset = True in [x < 1 for x in ratio_list] + def load_hierarchical_lmdb_dataset(self, data_dir): lmdb_sets = {} dataset_idx = 0 diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py index 5adcd02c4a..6f80179c4e 100644 --- a/ppocr/data/pgnet_dataset.py +++ b/ppocr/data/pgnet_dataset.py @@ -49,6 +49,8 @@ class PGDataSet(Dataset): self.ops = create_operators(dataset_config['transforms'], global_config) + self.need_reset = True in [x < 1 for x in ratio_list] + def shuffle_data_random(self): if self.do_shuffle: random.seed(self.seed) diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py index 78b76c5afb..671cda76fb 100644 --- a/ppocr/data/pubtab_dataset.py +++ b/ppocr/data/pubtab_dataset.py @@ -53,6 +53,9 @@ class PubTabDataSet(Dataset): self.shuffle_data_random() self.ops = create_operators(dataset_config['transforms'], global_config) + ratio_list = dataset_config.get("ratio_list", [1.0]) + self.need_reset = True in [x < 1 for x in ratio_list] + def shuffle_data_random(self): if 
self.do_shuffle: random.seed(self.seed) @@ -70,7 +73,7 @@ class PubTabDataSet(Dataset): prob = self.img_select_prob[file_name] if prob < random.uniform(0, 1): select_flag = False - + if self.table_select_type: structure = info['html']['structure']['tokens'].copy() structure_str = ''.join(structure) @@ -79,13 +82,17 @@ class PubTabDataSet(Dataset): table_type = "complex" if table_type == "complex": if self.table_select_prob < random.uniform(0, 1): - select_flag = False - + select_flag = False + if select_flag: cells = info['html']['cells'].copy() structure = info['html']['structure'].copy() img_path = os.path.join(self.data_dir, file_name) - data = {'img_path': img_path, 'cells': cells, 'structure':structure} + data = { + 'img_path': img_path, + 'cells': cells, + 'structure': structure + } if not os.path.exists(img_path): raise Exception("{} does not exist!".format(img_path)) with open(data['img_path'], 'rb') as f: diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index 9f0ce352d9..10b6b7a891 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -41,7 +41,6 @@ class SimpleDataSet(Dataset): ) == data_source_num, "The length of ratio_list should be the same as the file_list." self.data_dir = dataset_config['data_dir'] self.do_shuffle = loader_config['shuffle'] - self.seed = seed logger.info("Initialize indexs of datasets:%s" % label_file_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list) @@ -50,6 +49,8 @@ class SimpleDataSet(Dataset): self.shuffle_data_random() self.ops = create_operators(dataset_config['transforms'], global_config) + self.need_reset = True in [x < 1 for x in ratio_list] + def get_image_info_list(self, file_list, ratio_list): if isinstance(file_list, str): file_list = [file_list] diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 62ad2b6ad8..56e6d25d4b 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -16,6 +16,9 @@ import copy import paddle import paddle.nn as nn +# basic_loss +from .basic_loss import LossFromOutput + # det loss from .det_db_loss import DBLoss from .det_east_loss import EASTLoss @@ -46,12 +49,16 @@ from .combined_loss import CombinedLoss # table loss from .table_att_loss import TableAttentionLoss +# vqa token loss +from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss + def build_loss(config): support_dict = [ 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', - 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss' + 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', + 'VQASerTokenLayoutLMLoss', 'LossFromOutput' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py index d2ef5e5ac9..fc64c133a4 100644 --- a/ppocr/losses/basic_loss.py +++ b/ppocr/losses/basic_loss.py @@ -133,3 +133,18 @@ class DistanceLoss(nn.Layer): def forward(self, x, y): return self.loss_func(x, y) + + +class LossFromOutput(nn.Layer): + def __init__(self, key='loss', reduction='none'): + super().__init__() + self.key = key + self.reduction = reduction + + def forward(self, predicts, batch): + loss = predicts[self.key] + if self.reduction == 'mean': + loss = paddle.mean(loss) + elif self.reduction == 'sum': + loss = paddle.sum(loss) + return {'loss': loss} diff --git a/ppstructure/vqa/losses.py b/ppocr/losses/vqa_token_layoutlm_loss.py old mode 100644 new mode 100755 similarity index 66% rename 
from ppstructure/vqa/losses.py rename to ppocr/losses/vqa_token_layoutlm_loss.py index e8dad01c31..244893d97d --- a/ppstructure/vqa/losses.py +++ b/ppocr/losses/vqa_token_layoutlm_loss.py @@ -1,10 +1,10 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,24 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + from paddle import nn -class SERLoss(nn.Layer): +class VQASerTokenLayoutLMLoss(nn.Layer): def __init__(self, num_classes): super().__init__() self.loss_class = nn.CrossEntropyLoss() self.num_classes = num_classes self.ignore_index = self.loss_class.ignore_index - def forward(self, labels, outputs, attention_mask): + def forward(self, predicts, batch): + labels = batch[1] + attention_mask = batch[4] if attention_mask is not None: active_loss = attention_mask.reshape([-1, ]) == 1 - active_outputs = outputs.reshape( + active_outputs = predicts.reshape( [-1, self.num_classes])[active_loss] active_labels = labels.reshape([-1, ])[active_loss] loss = self.loss_class(active_outputs, active_labels) else: loss = self.loss_class( - outputs.reshape([-1, self.num_classes]), labels.reshape([-1, ])) - return loss + predicts.reshape([-1, self.num_classes]), + labels.reshape([-1, ])) + return {'loss': loss} diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index 28bff3cb4e..604ae548df 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -28,12 +28,15 @@ from .e2e_metric import E2EMetric from .distillation_metric import DistillationMetric from .table_metric import TableMetric from .kie_metric import KIEMetric +from .vqa_token_ser_metric import VQASerTokenMetric +from .vqa_token_re_metric import VQAReTokenMetric def build_metric(config): support_dict = [ "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", - "DistillationMetric", "TableMetric", 'KIEMetric' + "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric', + 'VQAReTokenMetric' ] config = copy.deepcopy(config) diff --git a/ppocr/metrics/vqa_token_re_metric.py b/ppocr/metrics/vqa_token_re_metric.py new file mode 100644 index 0000000000..8a13bc0812 --- /dev/null +++ b/ppocr/metrics/vqa_token_re_metric.py @@ -0,0 +1,176 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
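The loss above only scores tokens whose attention mask is 1, so padding never contributes to the gradient. A NumPy sketch of that masking logic (not the Paddle implementation; shapes and values are invented):

```python
# NumPy sketch of the "active loss" idea used by VQASerTokenLayoutLMLoss above:
# cross-entropy is computed only where attention_mask == 1, so padded tokens
# are ignored. Shapes and values are invented.
import numpy as np

def masked_token_ce(logits, labels, attention_mask):
    # logits: [batch, seq_len, num_classes]; labels/attention_mask: [batch, seq_len]
    num_classes = logits.shape[-1]
    flat_logits = logits.reshape(-1, num_classes)
    flat_labels = labels.reshape(-1)
    active = attention_mask.reshape(-1) == 1
    flat_logits, flat_labels = flat_logits[active], flat_labels[active]
    # softmax cross-entropy over the active tokens only
    shifted = flat_logits - flat_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(flat_labels)), flat_labels].mean()

logits = np.random.randn(2, 5, 7)          # 7 classes, as in the SER configs
labels = np.random.randint(0, 7, (2, 5))
mask = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]])
print(masked_token_ce(logits, labels, mask))
```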
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle + +__all__ = ['KIEMetric'] + + +class VQAReTokenMetric(object): + def __init__(self, main_indicator='hmean', **kwargs): + self.main_indicator = main_indicator + self.reset() + + def __call__(self, preds, batch, **kwargs): + pred_relations, relations, entities = preds + self.pred_relations_list.extend(pred_relations) + self.relations_list.extend(relations) + self.entities_list.extend(entities) + + def get_metric(self): + gt_relations = [] + for b in range(len(self.relations_list)): + rel_sent = [] + for head, tail in zip(self.relations_list[b]["head"], + self.relations_list[b]["tail"]): + rel = {} + rel["head_id"] = head + rel["head"] = (self.entities_list[b]["start"][rel["head_id"]], + self.entities_list[b]["end"][rel["head_id"]]) + rel["head_type"] = self.entities_list[b]["label"][rel[ + "head_id"]] + + rel["tail_id"] = tail + rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]], + self.entities_list[b]["end"][rel["tail_id"]]) + rel["tail_type"] = self.entities_list[b]["label"][rel[ + "tail_id"]] + + rel["type"] = 1 + rel_sent.append(rel) + gt_relations.append(rel_sent) + re_metrics = self.re_score( + self.pred_relations_list, gt_relations, mode="boundaries") + metrics = { + "precision": re_metrics["ALL"]["p"], + "recall": re_metrics["ALL"]["r"], + "hmean": re_metrics["ALL"]["f1"], + } + self.reset() + return metrics + + def reset(self): + self.pred_relations_list = [] + self.relations_list = [] + self.entities_list = [] + + def re_score(self, pred_relations, gt_relations, mode="strict"): + """Evaluate RE predictions + + Args: + pred_relations (list) : list of list of predicted relations (several relations in each sentence) + gt_relations (list) : list of list of ground truth relations + + rel = { "head": (start_idx (inclusive), end_idx (exclusive)), + "tail": (start_idx (inclusive), end_idx (exclusive)), + "head_type": ent_type, + "tail_type": ent_type, + "type": rel_type} + + vocab (Vocab) : dataset vocabulary + mode (str) : in 'strict' or 'boundaries'""" + + assert mode in ["strict", "boundaries"] + + relation_types = [v for v in [0, 1] if not v == 0] + scores = { + rel: { + "tp": 0, + "fp": 0, + "fn": 0 + } + for rel in relation_types + ["ALL"] + } + + # Count GT relations and Predicted relations + n_sents = len(gt_relations) + n_rels = sum([len([rel for rel in sent]) for sent in gt_relations]) + n_found = sum([len([rel for rel in sent]) for sent in pred_relations]) + + # Count TP, FP and FN per type + for pred_sent, gt_sent in zip(pred_relations, gt_relations): + for rel_type in relation_types: + # strict mode takes argument types into account + if mode == "strict": + pred_rels = {(rel["head"], rel["head_type"], rel["tail"], + rel["tail_type"]) + for rel in pred_sent + if rel["type"] == rel_type} + gt_rels = {(rel["head"], rel["head_type"], rel["tail"], + rel["tail_type"]) + for rel in gt_sent if rel["type"] == rel_type} + + # boundaries mode only takes argument spans into account + elif mode == "boundaries": + pred_rels = {(rel["head"], rel["tail"]) + for rel in pred_sent + if rel["type"] == rel_type} + gt_rels = {(rel["head"], rel["tail"]) + for rel in gt_sent if rel["type"] == rel_type} + + scores[rel_type]["tp"] += len(pred_rels & gt_rels) + scores[rel_type]["fp"] += len(pred_rels - gt_rels) + scores[rel_type]["fn"] += len(gt_rels - pred_rels) + + # Compute per entity Precision / Recall / F1 + for rel_type in 
scores.keys(): + if scores[rel_type]["tp"]: + scores[rel_type]["p"] = scores[rel_type]["tp"] / ( + scores[rel_type]["fp"] + scores[rel_type]["tp"]) + scores[rel_type]["r"] = scores[rel_type]["tp"] / ( + scores[rel_type]["fn"] + scores[rel_type]["tp"]) + else: + scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0 + + if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0: + scores[rel_type]["f1"] = ( + 2 * scores[rel_type]["p"] * scores[rel_type]["r"] / + (scores[rel_type]["p"] + scores[rel_type]["r"])) + else: + scores[rel_type]["f1"] = 0 + + # Compute micro F1 Scores + tp = sum([scores[rel_type]["tp"] for rel_type in relation_types]) + fp = sum([scores[rel_type]["fp"] for rel_type in relation_types]) + fn = sum([scores[rel_type]["fn"] for rel_type in relation_types]) + + if tp: + precision = tp / (tp + fp) + recall = tp / (tp + fn) + f1 = 2 * precision * recall / (precision + recall) + + else: + precision, recall, f1 = 0, 0, 0 + + scores["ALL"]["p"] = precision + scores["ALL"]["r"] = recall + scores["ALL"]["f1"] = f1 + scores["ALL"]["tp"] = tp + scores["ALL"]["fp"] = fp + scores["ALL"]["fn"] = fn + + # Compute Macro F1 Scores + scores["ALL"]["Macro_f1"] = np.mean( + [scores[ent_type]["f1"] for ent_type in relation_types]) + scores["ALL"]["Macro_p"] = np.mean( + [scores[ent_type]["p"] for ent_type in relation_types]) + scores["ALL"]["Macro_r"] = np.mean( + [scores[ent_type]["r"] for ent_type in relation_types]) + + return scores diff --git a/ppocr/metrics/vqa_token_ser_metric.py b/ppocr/metrics/vqa_token_ser_metric.py new file mode 100644 index 0000000000..92d80d0970 --- /dev/null +++ b/ppocr/metrics/vqa_token_ser_metric.py @@ -0,0 +1,47 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
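A tiny worked example of the "boundaries" mode used by `re_score` above: a predicted relation counts as a true positive when its (head span, tail span) pair appears in the ground truth, regardless of entity types. The spans are invented.

```python
# Worked example of the set arithmetic behind re_score(..., mode="boundaries").
pred_sent = [{"head": (0, 3), "tail": (5, 9), "type": 1},
             {"head": (10, 12), "tail": (20, 25), "type": 1}]
gt_sent = [{"head": (0, 3), "tail": (5, 9), "type": 1},
           {"head": (30, 34), "tail": (40, 44), "type": 1}]

pred_rels = {(r["head"], r["tail"]) for r in pred_sent if r["type"] == 1}
gt_rels = {(r["head"], r["tail"]) for r in gt_sent if r["type"] == 1}

tp = len(pred_rels & gt_rels)                           # 1
fp = len(pred_rels - gt_rels)                           # 1
fn = len(gt_rels - pred_rels)                           # 1
precision = tp / (tp + fp)                              # 0.5
recall = tp / (tp + fn)                                 # 0.5
hmean = 2 * precision * recall / (precision + recall)   # 0.5
print(precision, recall, hmean)
```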
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle + +__all__ = ['KIEMetric'] + + +class VQASerTokenMetric(object): + def __init__(self, main_indicator='hmean', **kwargs): + self.main_indicator = main_indicator + self.reset() + + def __call__(self, preds, batch, **kwargs): + preds, labels = preds + self.pred_list.extend(preds) + self.gt_list.extend(labels) + + def get_metric(self): + from seqeval.metrics import f1_score, precision_score, recall_score + metircs = { + "precision": precision_score(self.gt_list, self.pred_list), + "recall": recall_score(self.gt_list, self.pred_list), + "hmean": f1_score(self.gt_list, self.pred_list), + } + self.reset() + return metircs + + def reset(self): + self.pred_list = [] + self.gt_list = [] diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index c498d9862a..e622db2567 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -63,8 +63,12 @@ class BaseModel(nn.Layer): in_channels = self.neck.out_channels # # build head, head is need for det, rec and cls - config["Head"]['in_channels'] = in_channels - self.head = build_head(config["Head"]) + if 'Head' not in config or config['Head'] is None: + self.use_head = False + else: + self.use_head = True + config["Head"]['in_channels'] = in_channels + self.head = build_head(config["Head"]) self.return_all_feats = config.get("return_all_feats", False) @@ -77,7 +81,8 @@ class BaseModel(nn.Layer): if self.use_neck: x = self.neck(x) y["neck_out"] = x - x = self.head(x, targets=data) + if self.use_head: + x = self.head(x, targets=data) if isinstance(x, dict): y.update(x) else: diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index d10983487b..1af87a8156 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -43,6 +43,9 @@ def build_backbone(config, model_type): from .table_resnet_vd import ResNet from .table_mobilenet_v3 import MobileNetV3 support_dict = ["ResNet", "MobileNetV3"] + elif model_type == 'vqa': + from .vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer, LayoutXLMForRe + support_dict = ["LayoutLMForSer", "LayoutXLMForSer", 'LayoutXLMForRe'] else: raise NotImplementedError diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py new file mode 100644 index 0000000000..0e98155514 --- /dev/null +++ b/ppocr/modeling/backbones/vqa_layoutlm.py @@ -0,0 +1,125 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
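`VQASerTokenMetric` above delegates scoring to seqeval, which evaluates at the entity level: a B-/I- span only counts as correct when the whole span and its label match. A minimal usage example with invented BIO sequences:

```python
# Minimal example of the seqeval calls VQASerTokenMetric relies on above.
# The BIO sequences are invented; seqeval scores whole entity spans,
# not individual tokens.
from seqeval.metrics import f1_score, precision_score, recall_score

gt_list = [["B-QUESTION", "I-QUESTION", "O", "B-ANSWER", "I-ANSWER"]]
pred_list = [["B-QUESTION", "I-QUESTION", "O", "B-ANSWER", "O"]]

print({
    "precision": precision_score(gt_list, pred_list),
    "recall": recall_score(gt_list, pred_list),
    "hmean": f1_score(gt_list, pred_list),
})
```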
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from paddle import nn + +from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction +from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification + +__all__ = ["LayoutXLMForSer", 'LayoutLMForSer'] + +pretrained_model_dict = { + LayoutXLMModel: 'layoutxlm-base-uncased', + LayoutLMModel: 'layoutlm-base-uncased' +} + + +class NLPBaseModel(nn.Layer): + def __init__(self, + base_model_class, + model_class, + type='ser', + pretrained=True, + checkpoints=None, + **kwargs): + super(NLPBaseModel, self).__init__() + if checkpoints is not None: + self.model = model_class.from_pretrained(checkpoints) + else: + pretrained_model_name = pretrained_model_dict[base_model_class] + if pretrained: + base_model = base_model_class.from_pretrained( + pretrained_model_name) + else: + base_model = base_model_class( + **base_model_class.pretrained_init_configuration[ + pretrained_model_name]) + if type == 'ser': + self.model = model_class( + base_model, num_classes=kwargs['num_classes'], dropout=None) + else: + self.model = model_class(base_model, dropout=None) + self.out_channels = 1 + + +class LayoutXLMForSer(NLPBaseModel): + def __init__(self, num_classes, pretrained=True, checkpoints=None, + **kwargs): + super(LayoutXLMForSer, self).__init__( + LayoutXLMModel, + LayoutXLMForTokenClassification, + 'ser', + pretrained, + checkpoints, + num_classes=num_classes) + + def forward(self, x): + x = self.model( + input_ids=x[0], + bbox=x[2], + image=x[3], + attention_mask=x[4], + token_type_ids=x[5], + position_ids=None, + head_mask=None, + labels=None) + return x[0] + + +class LayoutLMForSer(NLPBaseModel): + def __init__(self, num_classes, pretrained=True, checkpoints=None, + **kwargs): + super(LayoutLMForSer, self).__init__( + LayoutLMModel, + LayoutLMForTokenClassification, + 'ser', + pretrained, + checkpoints, + num_classes=num_classes) + + def forward(self, x): + x = self.model( + input_ids=x[0], + bbox=x[2], + attention_mask=x[4], + token_type_ids=x[5], + position_ids=None, + output_hidden_states=False) + return x + + +class LayoutXLMForRe(NLPBaseModel): + def __init__(self, pretrained=True, checkpoints=None, **kwargs): + super(LayoutXLMForRe, self).__init__(LayoutXLMModel, + LayoutXLMForRelationExtraction, + 're', pretrained, checkpoints) + + def forward(self, x): + x = self.model( + input_ids=x[0], + bbox=x[1], + labels=None, + image=x[2], + attention_mask=x[3], + token_type_ids=x[4], + position_ids=None, + head_mask=None, + entities=x[5], + relations=x[6]) + return x diff --git a/ppocr/optimizer/__init__.py b/ppocr/optimizer/__init__.py index c729103a70..e0c6b90371 100644 --- a/ppocr/optimizer/__init__.py +++ b/ppocr/optimizer/__init__.py @@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters): # step2 build regularization if 'regularizer' in config and config['regularizer'] is not None: reg_config = config.pop('regularizer') - reg_name = reg_config.pop('name') + 'Decay' + reg_name = reg_config.pop('name') + if not hasattr(regularizer, reg_name): + reg_name += 'Decay' reg = getattr(regularizer, reg_name)(**reg_config)() else: reg = None diff --git a/ppocr/optimizer/optimizer.py b/ppocr/optimizer/optimizer.py index 34098c0fad..b98081227e 100644 --- a/ppocr/optimizer/optimizer.py +++ b/ppocr/optimizer/optimizer.py @@ -158,3 +158,38 @@ class Adadelta(object): name=self.name, 
parameters=parameters) return opt + + +class AdamW(object): + def __init__(self, + learning_rate=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + weight_decay=0.01, + grad_clip=None, + name=None, + lazy_mode=False, + **kwargs): + self.learning_rate = learning_rate + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.learning_rate = learning_rate + self.weight_decay = 0.01 if weight_decay is None else weight_decay + self.grad_clip = grad_clip + self.name = name + self.lazy_mode = lazy_mode + + def __call__(self, parameters): + opt = optim.AdamW( + learning_rate=self.learning_rate, + beta1=self.beta1, + beta2=self.beta2, + epsilon=self.epsilon, + weight_decay=self.weight_decay, + grad_clip=self.grad_clip, + name=self.name, + lazy_mode=self.lazy_mode, + parameters=parameters) + return opt diff --git a/ppocr/optimizer/regularizer.py b/ppocr/optimizer/regularizer.py index c6396f338d..2ce68f7139 100644 --- a/ppocr/optimizer/regularizer.py +++ b/ppocr/optimizer/regularizer.py @@ -29,24 +29,23 @@ class L1Decay(object): def __init__(self, factor=0.0): super(L1Decay, self).__init__() - self.regularization_coeff = factor + self.coeff = factor def __call__(self): - reg = paddle.regularizer.L1Decay(self.regularization_coeff) + reg = paddle.regularizer.L1Decay(self.coeff) return reg class L2Decay(object): """ - L2 Weight Decay Regularization, which encourages the weights to be sparse. + L2 Weight Decay Regularization, which helps to prevent the model over-fitting. Args: factor(float): regularization coeff. Default:0.0. """ def __init__(self, factor=0.0): super(L2Decay, self).__init__() - self.regularization_coeff = factor + self.coeff = float(factor) def __call__(self): - reg = paddle.regularizer.L2Decay(self.regularization_coeff) - return reg + return self.coeff \ No newline at end of file diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 37dadd12d3..811bf57b64 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -28,6 +28,8 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess +from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess +from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess def build_post_process(config, global_config=None): @@ -36,7 +38,8 @@ def build_post_process(config, global_config=None): 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', - 'SEEDLabelDecode' + 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', + 'VQAReTokenLayoutLMPostProcess' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py new file mode 100644 index 0000000000..1d55d13d76 --- /dev/null +++ b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py @@ -0,0 +1,51 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle + + +class VQAReTokenLayoutLMPostProcess(object): + """ Convert between text-label and text-index """ + + def __init__(self, **kwargs): + super(VQAReTokenLayoutLMPostProcess, self).__init__() + + def __call__(self, preds, label=None, *args, **kwargs): + if label is not None: + return self._metric(preds, label) + else: + return self._infer(preds, *args, **kwargs) + + def _metric(self, preds, label): + return preds['pred_relations'], label[6], label[5] + + def _infer(self, preds, *args, **kwargs): + ser_results = kwargs['ser_results'] + entity_idx_dict_batch = kwargs['entity_idx_dict_batch'] + pred_relations = preds['pred_relations'] + + # merge relations and ocr info + results = [] + for pred_relation, ser_result, entity_idx_dict in zip( + pred_relations, ser_results, entity_idx_dict_batch): + result = [] + used_tail_id = [] + for relation in pred_relation: + if relation['tail_id'] in used_tail_id: + continue + used_tail_id.append(relation['tail_id']) + ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]] + ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]] + result.append((ocr_info_head, ocr_info_tail)) + results.append(result) + return results diff --git a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py new file mode 100644 index 0000000000..782cdea6c5 --- /dev/null +++ b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py @@ -0,0 +1,93 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
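A standalone sketch of the relation-merging loop in `VQAReTokenLayoutLMPostProcess._infer` above: each predicted head/tail id pair is mapped back to OCR segments, and only the first relation seen for each tail is kept. All texts and ids below are invented.

```python
# Sketch of the merging step shown above, with invented SER results.
pred_relation = [{"head_id": 0, "tail_id": 1},
                 {"head_id": 2, "tail_id": 1},   # duplicate tail -> skipped
                 {"head_id": 2, "tail_id": 3}]
ser_result = [{"text": "Name:"}, {"text": "Zhang San"},
              {"text": "Date:"}, {"text": "2022-01-01"}]
entity_idx_dict = {0: 0, 1: 1, 2: 2, 3: 3}       # entity id -> index into ser_result

result, used_tail_id = [], []
for relation in pred_relation:
    if relation["tail_id"] in used_tail_id:
        continue
    used_tail_id.append(relation["tail_id"])
    head = ser_result[entity_idx_dict[relation["head_id"]]]
    tail = ser_result[entity_idx_dict[relation["tail_id"]]]
    result.append((head["text"], tail["text"]))
print(result)   # [('Name:', 'Zhang San'), ('Date:', '2022-01-01')]
```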
+import numpy as np +import paddle +from ppocr.utils.utility import load_vqa_bio_label_maps + + +class VQASerTokenLayoutLMPostProcess(object): + """ Convert between text-label and text-index """ + + def __init__(self, class_path, **kwargs): + super(VQASerTokenLayoutLMPostProcess, self).__init__() + label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path) + + self.label2id_map_for_draw = dict() + for key in label2id_map: + if key.startswith("I-"): + self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]] + else: + self.label2id_map_for_draw[key] = label2id_map[key] + + self.id2label_map_for_show = dict() + for key in self.label2id_map_for_draw: + val = self.label2id_map_for_draw[key] + if key == "O": + self.id2label_map_for_show[val] = key + if key.startswith("B-") or key.startswith("I-"): + self.id2label_map_for_show[val] = key[2:] + else: + self.id2label_map_for_show[val] = key + + def __call__(self, preds, batch=None, *args, **kwargs): + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + + if batch is not None: + return self._metric(preds, batch[1]) + else: + return self._infer(preds, **kwargs) + + def _metric(self, preds, label): + pred_idxs = preds.argmax(axis=2) + decode_out_list = [[] for _ in range(pred_idxs.shape[0])] + label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])] + + for i in range(pred_idxs.shape[0]): + for j in range(pred_idxs.shape[1]): + if label[i, j] != -100: + label_decode_out_list[i].append(self.id2label_map[label[i, + j]]) + decode_out_list[i].append(self.id2label_map[pred_idxs[i, + j]]) + return decode_out_list, label_decode_out_list + + def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos): + results = [] + + for pred, attention_mask, segment_offset_id, ocr_info in zip( + preds, attention_masks, segment_offset_ids, ocr_infos): + pred = np.argmax(pred, axis=1) + pred = [self.id2label_map[idx] for idx in pred] + + for idx in range(len(segment_offset_id)): + if idx == 0: + start_id = 0 + else: + start_id = segment_offset_id[idx - 1] + + end_id = segment_offset_id[idx] + + curr_pred = pred[start_id:end_id] + curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred] + + if len(curr_pred) <= 0: + pred_id = 0 + else: + counts = np.bincount(curr_pred) + pred_id = np.argmax(counts) + ocr_info[idx]["pred_id"] = int(pred_id) + ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)] + results.append(ocr_info) + return results diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 0dd94e86c8..b09f1db6e9 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger): raise OSError('Failed to mkdir {}'.format(path)) -def load_model(config, model, optimizer=None): +def load_model(config, model, optimizer=None, model_type='det'): """ load model from checkpoint or pretrained_model """ @@ -53,6 +53,33 @@ def load_model(config, model, optimizer=None): checkpoints = global_config.get('checkpoints') pretrained_model = global_config.get('pretrained_model') best_model_dict = {} + + if model_type == 'vqa': + checkpoints = config['Architecture']['Backbone']['checkpoints'] + # load vqa method metric + if checkpoints: + if os.path.exists(os.path.join(checkpoints, 'metric.states')): + with open(os.path.join(checkpoints, 'metric.states'), + 'rb') as f: + states_dict = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') + best_model_dict = states_dict.get('best_model_dict', {}) + if 'epoch' in states_dict: + 
best_model_dict['start_epoch'] = states_dict['epoch'] + 1 + logger.info("resume from {}".format(checkpoints)) + + if optimizer is not None: + if checkpoints[-1] in ['/', '\\']: + checkpoints = checkpoints[:-1] + if os.path.exists(checkpoints + '.pdopt'): + optim_dict = paddle.load(checkpoints + '.pdopt') + optimizer.set_state_dict(optim_dict) + else: + logger.warning( + "{}.pdopt is not exists, params of optimizer is not loaded". + format(checkpoints)) + return best_model_dict + if checkpoints: if checkpoints.endswith('.pdparams'): checkpoints = checkpoints.replace('.pdparams', '') @@ -130,6 +157,7 @@ def save_model(model, optimizer, model_path, logger, + config, is_best=False, prefix='ppocr', **kwargs): @@ -138,13 +166,20 @@ def save_model(model, """ _mkdir_if_not_exist(model_path, logger) model_prefix = os.path.join(model_path, prefix) - paddle.save(model.state_dict(), model_prefix + '.pdparams') paddle.save(optimizer.state_dict(), model_prefix + '.pdopt') - + if config['Architecture']["model_type"] != 'vqa': + paddle.save(model.state_dict(), model_prefix + '.pdparams') + metric_prefix = model_prefix + else: + if config['Global']['distributed']: + model._layers.backbone.model.save_pretrained(model_prefix) + else: + model.backbone.model.save_pretrained(model_prefix) + metric_prefix = os.path.join(model_prefix, 'metric') # save metric and config - with open(model_prefix + '.states', 'wb') as f: - pickle.dump(kwargs, f, protocol=2) if is_best: + with open(metric_prefix + '.states', 'wb') as f: + pickle.dump(kwargs, f, protocol=2) logger.info('save best model is to {}'.format(model_prefix)) else: logger.info("save model in {}".format(model_prefix)) diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py index 7bb4c906d2..76484dfd3d 100755 --- a/ppocr/utils/utility.py +++ b/ppocr/utils/utility.py @@ -16,6 +16,9 @@ import logging import os import imghdr import cv2 +import random +import numpy as np +import paddle def print_dict(d, logger, delimiter=0): @@ -77,4 +80,28 @@ def check_and_read_gif(img_path): frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) imgvalue = frame[:, :, ::-1] return imgvalue, True - return None, False \ No newline at end of file + return None, False + + +def load_vqa_bio_label_maps(label_map_path): + with open(label_map_path, "r", encoding='utf-8') as fin: + lines = fin.readlines() + lines = [line.strip() for line in lines] + if "O" not in lines: + lines.insert(0, "O") + labels = [] + for line in lines: + if line == "O": + labels.append("O") + else: + labels.append("B-" + line) + labels.append("I-" + line) + label2id_map = {label: idx for idx, label in enumerate(labels)} + id2label_map = {idx: label for idx, label in enumerate(labels)} + return label2id_map, id2label_map + + +def set_seed(seed=1024): + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) diff --git a/ppocr/utils/visual.py b/ppocr/utils/visual.py new file mode 100644 index 0000000000..7a8c1674a7 --- /dev/null +++ b/ppocr/utils/visual.py @@ -0,0 +1,98 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
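`load_vqa_bio_label_maps` above expands each class name in the label file into `B-`/`I-` tags plus a leading `O`. A small sketch of that expansion, assuming the label file lists the three XFUND classes (order illustrative):

```python
# Sketch of the BIO expansion done by load_vqa_bio_label_maps, with the file
# contents replaced by an in-memory list of assumed class names.
raw_labels = ["QUESTION", "ANSWER", "HEADER"]   # one class name per line in labels_ser.txt

lines = list(raw_labels)
if "O" not in lines:
    lines.insert(0, "O")

labels = []
for line in lines:
    if line == "O":
        labels.append("O")
    else:
        labels.append("B-" + line)
        labels.append("I-" + line)

label2id_map = {label: idx for idx, label in enumerate(labels)}
id2label_map = {idx: label for idx, label in enumerate(labels)}

print(label2id_map)
# {'O': 0, 'B-QUESTION': 1, 'I-QUESTION': 2, 'B-ANSWER': 3, 'I-ANSWER': 4,
#  'B-HEADER': 5, 'I-HEADER': 6}   -> 7 ids in total
```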
+# See the License for the specific language governing permissions and +# limitations under the License. +import os +import numpy as np +from PIL import Image, ImageDraw, ImageFont + + +def draw_ser_results(image, + ocr_results, + font_path="doc/fonts/simfang.ttf", + font_size=18): + np.random.seed(2021) + color = (np.random.permutation(range(255)), + np.random.permutation(range(255)), + np.random.permutation(range(255))) + color_map = { + idx: (color[0][idx], color[1][idx], color[2][idx]) + for idx in range(1, 255) + } + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + elif isinstance(image, str) and os.path.isfile(image): + image = Image.open(image).convert('RGB') + img_new = image.copy() + draw = ImageDraw.Draw(img_new) + + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + for ocr_info in ocr_results: + if ocr_info["pred_id"] not in color_map: + continue + color = color_map[ocr_info["pred_id"]] + text = "{}: {}".format(ocr_info["pred"], ocr_info["text"]) + + draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color) + + img_new = Image.blend(image, img_new, 0.5) + return np.array(img_new) + + +def draw_box_txt(bbox, text, draw, font, font_size, color): + # draw ocr results outline + bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3])) + draw.rectangle(bbox, fill=color) + + # draw ocr results + start_y = max(0, bbox[0][1] - font_size) + tw = font.getsize(text)[0] + draw.rectangle( + [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)], + fill=(0, 0, 255)) + draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font) + + +def draw_re_results(image, + result, + font_path="doc/fonts/simfang.ttf", + font_size=18): + np.random.seed(0) + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + elif isinstance(image, str) and os.path.isfile(image): + image = Image.open(image).convert('RGB') + img_new = image.copy() + draw = ImageDraw.Draw(img_new) + + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + color_head = (0, 0, 255) + color_tail = (255, 0, 0) + color_line = (0, 255, 0) + + for ocr_info_head, ocr_info_tail in result: + draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font, + font_size, color_head) + draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font, + font_size, color_tail) + + center_head = ( + (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2, + (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2) + center_tail = ( + (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2, + (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2) + + draw.line([center_head, center_tail], fill=color_line, width=5) + + img_new = Image.blend(image, img_new, 0.5) + return np.array(img_new) diff --git a/ppstructure/docs/model_list.md b/ppstructure/docs/model_list.md index 45004490c1..baec2a2fd0 100644 --- a/ppstructure/docs/model_list.md +++ b/ppstructure/docs/model_list.md @@ -24,8 +24,8 @@ |ęØ”åž‹åē§°|ęØ”åž‹ē®€ä»‹|ęŽØē†ęØ”åž‹å¤§å°|äø‹č½½åœ°å€| | --- | --- | --- | --- | -|PP-Layout_v1.0_ser_pretrained|åŸŗäŗŽLayoutXLM在xfunäø­ę–‡ę•°ę®é›†äøŠč®­ē»ƒēš„SERęØ”åž‹|1.4G|[ęŽØē†ęØ”åž‹ coming soon]() / [č®­ē»ƒęØ”åž‹](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) | -|PP-Layout_v1.0_re_pretrained|åŸŗäŗŽLayoutXLM在xfunäø­ę–‡ę•°ę®é›†äøŠč®­ē»ƒēš„REęØ”åž‹|1.4G|[ęŽØē†ęØ”åž‹ coming soon]() / [č®­ē»ƒęØ”åž‹](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) | 
+|PP-Layout_v1.0_ser_pretrained|åŸŗäŗŽLayoutXLM在xfunäø­ę–‡ę•°ę®é›†äøŠč®­ē»ƒēš„SERęØ”åž‹|1.4G|[ęŽØē†ęØ”åž‹ coming soon]() / [č®­ē»ƒęØ”åž‹](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) | +|PP-Layout_v1.0_re_pretrained|åŸŗäŗŽLayoutXLM在xfunäø­ę–‡ę•°ę®é›†äøŠč®­ē»ƒēš„REęØ”åž‹|1.4G|[ęŽØē†ęØ”åž‹ coming soon]() / [č®­ē»ƒęØ”åž‹](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) | ## 3. KIEęØ”åž‹ diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md index 36f2e6aad0..7f4ca119f7 100644 --- a/ppstructure/vqa/README.md +++ b/ppstructure/vqa/README.md @@ -20,11 +20,11 @@ PP-Structure é‡Œēš„ DOC-VQAē®—ę³•åŸŗäŗŽPaddleNLP自然语言处理算法库进 ęˆ‘ä»¬åœØ [XFUN](https://github.com/doc-analysis/XFUND) ēš„äø­ę–‡ę•°ę®é›†äøŠåÆ¹ē®—ę³•čæ›č”Œäŗ†čÆ„ä¼°ļ¼Œę€§čƒ½å¦‚äø‹ -| ęØ”åž‹ | 任劔 | f1 | ęØ”åž‹äø‹č½½åœ°å€ | +| ęØ”åž‹ | 任劔 | hmean | ęØ”åž‹äø‹č½½åœ°å€ | |:---:|:---:|:---:| :---:| -| LayoutXLM | RE | 0.7113 | [é“¾ęŽ„](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) | -| LayoutXLM | SER | 0.9056 | [é“¾ęŽ„](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) | -| LayoutLM | SER | 0.78 | [é“¾ęŽ„](https://paddleocr.bj.bcebos.com/pplayout/LayoutLM_ser_pretrained.tar) | +| LayoutXLM | RE | 0.7483 | [é“¾ęŽ„](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) | +| LayoutXLM | SER | 0.9038 | [é“¾ęŽ„](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) | +| LayoutLM | SER | 0.7731 | [é“¾ęŽ„](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) | @@ -34,7 +34,7 @@ PP-Structure é‡Œēš„ DOC-VQAē®—ę³•åŸŗäŗŽPaddleNLP自然语言处理算法库进 ### 2.1 SER -![](./images/result_ser/zh_val_0_ser.jpg) | ![](./images/result_ser/zh_val_42_ser.jpg) +![](../../doc/vqa/result_ser/zh_val_0_ser.jpg) | ![](../../doc/vqa/result_ser/zh_val_42_ser.jpg) ---|--- å›¾äø­äøåŒé¢œč‰²ēš„ę”†č”Øē¤ŗäøåŒēš„ē±»åˆ«ļ¼ŒåÆ¹äŗŽXFUNę•°ę®é›†ļ¼Œęœ‰`QUESTION`, `ANSWER`, `HEADER` 3ē§ē±»åˆ« @@ -48,7 +48,7 @@ PP-Structure é‡Œēš„ DOC-VQAē®—ę³•åŸŗäŗŽPaddleNLP自然语言处理算法库进 ### 2.2 RE -![](./images/result_re/zh_val_21_re.jpg) | ![](./images/result_re/zh_val_40_re.jpg) +![](../../doc/vqa/result_re/zh_val_21_re.jpg) | ![](../../doc/vqa/result_re/zh_val_40_re.jpg) ---|--- @@ -65,10 +65,10 @@ PP-Structure é‡Œēš„ DOC-VQAē®—ę³•åŸŗäŗŽPaddleNLP自然语言处理算法库进 python3 -m pip install --upgrade pip # GPU安装 -python3 -m pip install paddlepaddle-gpu==2.2 -i https://mirror.baidu.com/pypi/simple +python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple # CPU安装 -python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple +python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple ``` ę›“å¤šéœ€ę±‚ļ¼ŒčÆ·å‚ē…§[安装文攣](https://www.paddlepaddle.org.cn/install/quick)äø­ēš„čÆ“ę˜Žčæ›č”Œę“ä½œć€‚ @@ -93,11 +93,10 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR # ę³Øļ¼šē äŗ‘ę‰˜ē®”ä»£ē åÆčƒ½ę— ę³•å®žę—¶åŒę­„ęœ¬githubé”¹ē›®ę›“ę–°ļ¼Œå­˜åœØ3~5å¤©å»¶ę—¶ļ¼ŒčÆ·ä¼˜å…ˆä½æē”ØęŽØčę–¹å¼ć€‚ ``` -- **(4)安装VQAēš„`requirements`** +- **(3)安装VQAēš„`requirements`** ```bash -cd ppstructure/vqa -python3 -m pip install -r requirements.txt +python3 -m pip install -r ppstructure/vqa/requirements.txt ``` ## 4. 
使用 @@ -105,6 +104,10 @@ python3 -m pip install -r requirements.txt ### 4.1 ę•°ę®å’Œé¢„č®­ē»ƒęØ”åž‹å‡†å¤‡ +å¦‚ęžœåøŒęœ›ē›“ęŽ„ä½“éŖŒé¢„ęµ‹čæ‡ēØ‹ļ¼ŒåÆä»„äø‹č½½ęˆ‘ä»¬ęä¾›ēš„é¢„č®­ē»ƒęØ”åž‹ļ¼Œč·³čæ‡č®­ē»ƒčæ‡ēØ‹ļ¼Œē›“ęŽ„é¢„ęµ‹å³åÆć€‚ + +* äø‹č½½å¤„ē†å„½ēš„ę•°ę®é›† + å¤„ē†å„½ēš„XFUNäø­ę–‡ę•°ę®é›†äø‹č½½åœ°å€ļ¼š[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)怂 @@ -114,98 +117,62 @@ python3 -m pip install -r requirements.txt wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar ``` -å¦‚ęžœåøŒęœ›č½¬ę¢XFUNäø­å…¶ä»–čÆ­čØ€ēš„ę•°ę®é›†ļ¼ŒåÆä»„å‚č€ƒ[XFUNę•°ę®č½¬ę¢č„šęœ¬](helper/trans_xfun_data.py)怂 +* č½¬ę¢ę•°ę®é›† -å¦‚ęžœåøŒęœ›ē›“ęŽ„ä½“éŖŒé¢„ęµ‹čæ‡ēØ‹ļ¼ŒåÆä»„äø‹č½½ęˆ‘ä»¬ęä¾›ēš„é¢„č®­ē»ƒęØ”åž‹ļ¼Œč·³čæ‡č®­ē»ƒčæ‡ēØ‹ļ¼Œē›“ęŽ„é¢„ęµ‹å³åÆć€‚ +č‹„éœ€čæ›č”Œå…¶ä»–XFUNę•°ę®é›†ēš„č®­ē»ƒļ¼ŒåÆä½æē”Øäø‹é¢ēš„å‘½ä»¤čæ›č”Œę•°ę®é›†ēš„č½¬ę¢ +```bash +python3 ppstructure/vqa/helper/trans_xfun_data.py --ori_gt_path=path/to/json_path --output_path=path/to/save_path +``` ### 4.2 SER任劔 -* 启动训练 +åÆåŠØč®­ē»ƒä¹‹å‰ļ¼Œéœ€č¦äæ®ę”¹äø‹é¢ēš„å››äøŖå­—ę®µ +1. `Train.dataset.data_dir`ļ¼šęŒ‡å‘č®­ē»ƒé›†å›¾ē‰‡å­˜ę”¾ē›®å½• +2. `Train.dataset.label_file_list`ļ¼šęŒ‡å‘č®­ē»ƒé›†ę ‡ę³Øę–‡ä»¶ +3. `Eval.dataset.data_dir`ļ¼šęŒ‡ęŒ‡å‘éŖŒčÆé›†å›¾ē‰‡å­˜ę”¾ē›®å½• +4. `Eval.dataset.label_file_list`ļ¼šęŒ‡å‘éŖŒčÆé›†ę ‡ę³Øę–‡ä»¶ + +* 启动训练 ```shell -python3 train_ser.py \ - --model_name_or_path "layoutxlm-base-uncased" \ - --ser_model_type "LayoutXLM" \ - --train_data_dir "XFUND/zh_train/image" \ - --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \ - --eval_data_dir "XFUND/zh_val/image" \ - --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \ - --num_train_epochs 200 \ - --eval_steps 10 \ - --output_dir "./output/ser/" \ - --learning_rate 5e-5 \ - --warmup_steps 50 \ - --evaluate_during_training \ - --seed 2048 +CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml ``` -ęœ€ē»ˆä¼šę‰“å°å‡ŗ`precision`, `recall`, `f1`ē­‰ęŒ‡ę ‡ļ¼ŒęØ”åž‹å’Œč®­ē»ƒę—„åæ—ä¼šäæå­˜åœØ`./output/ser/`文件夹中。 +ęœ€ē»ˆä¼šę‰“å°å‡ŗ`precision`, `recall`, `hmean`ē­‰ęŒ‡ę ‡ć€‚ +在`./output/ser_layoutxlm/`ę–‡ä»¶å¤¹äø­ä¼šäæå­˜č®­ē»ƒę—„åæ—ļ¼Œęœ€ä¼˜ēš„ęØ”åž‹å’Œęœ€ę–°epochēš„ęØ”åž‹ć€‚ * ę¢å¤č®­ē»ƒ +ę¢å¤č®­ē»ƒéœ€č¦å°†ä¹‹å‰č®­ē»ƒå„½ēš„ęØ”åž‹ę‰€åœØę–‡ä»¶å¤¹č·Æå¾„čµ‹å€¼ē»™ `Architecture.Backbone.checkpoints` 字段。 + ```shell -python3 train_ser.py \ - --model_name_or_path "model_path" \ - --ser_model_type "LayoutXLM" \ - --train_data_dir "XFUND/zh_train/image" \ - --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \ - --eval_data_dir "XFUND/zh_val/image" \ - --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \ - --num_train_epochs 200 \ - --eval_steps 10 \ - --output_dir "./output/ser/" \ - --learning_rate 5e-5 \ - --warmup_steps 50 \ - --evaluate_during_training \ - --num_workers 8 \ - --seed 2048 \ - --resume +CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir ``` * 评估 -```shell -export CUDA_VISIBLE_DEVICES=0 -python3 eval_ser.py \ - --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \ - --ser_model_type "LayoutXLM" \ - --eval_data_dir "XFUND/zh_val/image" \ - --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \ - --per_gpu_eval_batch_size 8 \ - --num_workers 8 \ - --output_dir "output/ser/" \ - --seed 2048 -``` -ęœ€ē»ˆä¼šę‰“å°å‡ŗ`precision`, `recall`, `f1`ē­‰ęŒ‡ę ‡ -* ä½æē”ØčÆ„ä¼°é›†åˆäø­ęä¾›ēš„OCRčÆ†åˆ«ē»“ęžœčæ›č”Œé¢„ęµ‹ 
+čÆ„ä¼°éœ€č¦å°†å¾…čÆ„ä¼°ēš„ęØ”åž‹ę‰€åœØę–‡ä»¶å¤¹č·Æå¾„čµ‹å€¼ē»™ `Architecture.Backbone.checkpoints` 字段。 ```shell -export CUDA_VISIBLE_DEVICES=0 -python3 infer_ser.py \ - --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \ - --ser_model_type "LayoutXLM" \ - --output_dir "output/ser/" \ - --infer_imgs "XFUND/zh_val/image/" \ - --ocr_json_path "XFUND/zh_val/xfun_normalize_val.json" +CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir ``` +ęœ€ē»ˆä¼šę‰“å°å‡ŗ`precision`, `recall`, `hmean`ē­‰ęŒ‡ę ‡ -ęœ€ē»ˆä¼šåœØ`output_res`ē›®å½•äø‹äæå­˜é¢„ęµ‹ē»“ęžœåÆč§†åŒ–å›¾åƒä»„åŠé¢„ęµ‹ē»“ęžœę–‡ęœ¬ę–‡ä»¶ļ¼Œę–‡ä»¶åäøŗ`infer_results.txt`怂 +* 使用`OCRå¼•ę“Ž + SER`串联预测 -* 使用`OCRå¼•ę“Ž + SER`äø²č”ē»“ęžœ +ä½æē”Øå¦‚äø‹å‘½ä»¤å³åÆå®Œęˆ`OCRå¼•ę“Ž + SER`ēš„äø²č”é¢„ęµ‹ ```shell -export CUDA_VISIBLE_DEVICES=0 -python3 infer_ser_e2e.py \ - --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \ - --ser_model_type "LayoutXLM" \ - --max_seq_length 512 \ - --output_dir "output/ser_e2e/" \ - --infer_imgs "images/input/zh_val_0.jpg" +CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=PP-Layout_v1.0_ser_pretrained/ Global.infer_img=doc/vqa/input/zh_val_42.jpg ``` +ęœ€ē»ˆä¼šåœØ`config.Global.save_res_path`å­—ę®µę‰€é…ē½®ēš„ē›®å½•äø‹äæå­˜é¢„ęµ‹ē»“ęžœåÆč§†åŒ–å›¾åƒä»„åŠé¢„ęµ‹ē»“ęžœę–‡ęœ¬ę–‡ä»¶ļ¼Œé¢„ęµ‹ē»“ęžœę–‡ęœ¬ę–‡ä»¶åäøŗ`infer_results.txt`怂 + * 对`OCRå¼•ę“Ž + SER`é¢„ęµ‹ē³»ē»Ÿčæ›č”Œē«Æåˆ°ē«ÆčÆ„ä¼° +é¦–å…ˆä½æē”Ø `tools/infer_vqa_token_ser.py` č„šęœ¬å®Œęˆę•°ę®é›†ēš„é¢„ęµ‹ļ¼Œē„¶åŽä½æē”Øäø‹é¢ēš„å‘½ä»¤čæ›č”ŒčÆ„ä¼°ć€‚ + ```shell export CUDA_VISIBLE_DEVICES=0 python3 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt @@ -216,102 +183,48 @@ python3 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_norma * 启动训练 -```shell -export CUDA_VISIBLE_DEVICES=0 -python3 train_re.py \ - --model_name_or_path "layoutxlm-base-uncased" \ - --train_data_dir "XFUND/zh_train/image" \ - --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \ - --eval_data_dir "XFUND/zh_val/image" \ - --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \ - --label_map_path "labels/labels_ser.txt" \ - --num_train_epochs 200 \ - --eval_steps 10 \ - --output_dir "output/re/" \ - --learning_rate 5e-5 \ - --warmup_steps 50 \ - --per_gpu_train_batch_size 8 \ - --per_gpu_eval_batch_size 8 \ - --num_workers 8 \ - --evaluate_during_training \ - --seed 2048 +åÆåŠØč®­ē»ƒä¹‹å‰ļ¼Œéœ€č¦äæ®ę”¹äø‹é¢ēš„å››äøŖå­—ę®µ +1. `Train.dataset.data_dir`ļ¼šęŒ‡å‘č®­ē»ƒé›†å›¾ē‰‡å­˜ę”¾ē›®å½• +2. `Train.dataset.label_file_list`ļ¼šęŒ‡å‘č®­ē»ƒé›†ę ‡ę³Øę–‡ä»¶ +3. `Eval.dataset.data_dir`ļ¼šęŒ‡ęŒ‡å‘éŖŒčÆé›†å›¾ē‰‡å­˜ę”¾ē›®å½• +4. 
`Eval.dataset.label_file_list`ļ¼šęŒ‡å‘éŖŒčÆé›†ę ‡ę³Øę–‡ä»¶ + +```shell +CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml ``` +ęœ€ē»ˆä¼šę‰“å°å‡ŗ`precision`, `recall`, `hmean`ē­‰ęŒ‡ę ‡ć€‚ +在`./output/re_layoutxlm/`ę–‡ä»¶å¤¹äø­ä¼šäæå­˜č®­ē»ƒę—„åæ—ļ¼Œęœ€ä¼˜ēš„ęØ”åž‹å’Œęœ€ę–°epochēš„ęØ”åž‹ć€‚ + * ę¢å¤č®­ē»ƒ +ę¢å¤č®­ē»ƒéœ€č¦å°†ä¹‹å‰č®­ē»ƒå„½ēš„ęØ”åž‹ę‰€åœØę–‡ä»¶å¤¹č·Æå¾„čµ‹å€¼ē»™ `Architecture.Backbone.checkpoints` 字段。 + ```shell -export CUDA_VISIBLE_DEVICES=0 -python3 train_re.py \ - --model_name_or_path "model_path" \ - --train_data_dir "XFUND/zh_train/image" \ - --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \ - --eval_data_dir "XFUND/zh_val/image" \ - --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \ - --label_map_path "labels/labels_ser.txt" \ - --num_train_epochs 2 \ - --eval_steps 10 \ - --output_dir "output/re/" \ - --learning_rate 5e-5 \ - --warmup_steps 50 \ - --per_gpu_train_batch_size 8 \ - --per_gpu_eval_batch_size 8 \ - --num_workers 8 \ - --evaluate_during_training \ - --seed 2048 \ - --resume - +CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir ``` -ęœ€ē»ˆä¼šę‰“å°å‡ŗ`precision`, `recall`, `f1`ē­‰ęŒ‡ę ‡ļ¼ŒęØ”åž‹å’Œč®­ē»ƒę—„åæ—ä¼šäæå­˜åœØ`./output/re/`文件夹中。 - * 评估 -```shell -export CUDA_VISIBLE_DEVICES=0 -python3 eval_re.py \ - --model_name_or_path "PP-Layout_v1.0_re_pretrained/" \ - --max_seq_length 512 \ - --eval_data_dir "XFUND/zh_val/image" \ - --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \ - --label_map_path "labels/labels_ser.txt" \ - --output_dir "output/re/" \ - --per_gpu_eval_batch_size 8 \ - --num_workers 8 \ - --seed 2048 -``` -ęœ€ē»ˆä¼šę‰“å°å‡ŗ`precision`, `recall`, `f1`ē­‰ęŒ‡ę ‡ - -* ä½æē”ØčÆ„ä¼°é›†åˆäø­ęä¾›ēš„OCRčÆ†åˆ«ē»“ęžœčæ›č”Œé¢„ęµ‹ +čÆ„ä¼°éœ€č¦å°†å¾…čÆ„ä¼°ēš„ęØ”åž‹ę‰€åœØę–‡ä»¶å¤¹č·Æå¾„čµ‹å€¼ē»™ `Architecture.Backbone.checkpoints` 字段。 ```shell -export CUDA_VISIBLE_DEVICES=0 -python3 infer_re.py \ - --model_name_or_path "PP-Layout_v1.0_re_pretrained/" \ - --max_seq_length 512 \ - --eval_data_dir "XFUND/zh_val/image" \ - --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \ - --label_map_path "labels/labels_ser.txt" \ - --output_dir "output/re/" \ - --per_gpu_eval_batch_size 1 \ - --seed 2048 +CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir ``` +ęœ€ē»ˆä¼šę‰“å°å‡ŗ`precision`, `recall`, `hmean`ē­‰ęŒ‡ę ‡ -ęœ€ē»ˆä¼šåœØ`output_res`ē›®å½•äø‹äæå­˜é¢„ęµ‹ē»“ęžœåÆč§†åŒ–å›¾åƒä»„åŠé¢„ęµ‹ē»“ęžœę–‡ęœ¬ę–‡ä»¶ļ¼Œę–‡ä»¶åäøŗ`infer_results.txt`怂 - -* 使用`OCRå¼•ę“Ž + SER + RE`äø²č”ē»“ęžœ +* 使用`OCRå¼•ę“Ž + SER + RE`串联预测 +ä½æē”Øå¦‚äø‹å‘½ä»¤å³åÆå®Œęˆ`OCRå¼•ę“Ž + SER + RE`ēš„äø²č”é¢„ęµ‹ ```shell export CUDA_VISIBLE_DEVICES=0 -python3 infer_ser_re_e2e.py \ - --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \ - --re_model_name_or_path "PP-Layout_v1.0_re_pretrained/" \ - --ser_model_type "LayoutXLM" \ - --max_seq_length 512 \ - --output_dir "output/ser_re_e2e/" \ - --infer_imgs "images/input/zh_val_21.jpg" +python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=PP-Layout_v1.0_re_pretrained/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=PP-Layout_v1.0_ser_pretrained/ ``` 
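The chained prediction above yields, per image, a list of `(head, tail)` OCR-info pairs, as produced by `VQAReTokenLayoutLMPostProcess` and consumed by `draw_re_results`. A hedged sketch of flattening such pairs into a question-to-answer dict; only the `text` key is taken from the code above, the helper itself is hypothetical:

```python
# Hypothetical helper: turn RE output pairs into a flat key -> value mapping.
def pairs_to_kv(re_result):
    kv = {}
    for ocr_info_head, ocr_info_tail in re_result:
        kv[ocr_info_head["text"]] = ocr_info_tail["text"]
    return kv

example = [({"text": "姓名"}, {"text": "张三"}),
           ({"text": "ę—„ęœŸ"}, {"text": "2021-12-01"})]
print(pairs_to_kv(example))   # {'姓名': '张三', 'ę—„ęœŸ': '2021-12-01'}
```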
+ęœ€ē»ˆä¼šåœØ`config.Global.save_res_path`å­—ę®µę‰€é…ē½®ēš„ē›®å½•äø‹äæå­˜é¢„ęµ‹ē»“ęžœåÆč§†åŒ–å›¾åƒä»„åŠé¢„ęµ‹ē»“ęžœę–‡ęœ¬ę–‡ä»¶ļ¼Œé¢„ęµ‹ē»“ęžœę–‡ęœ¬ę–‡ä»¶åäøŗ`infer_results.txt`怂 + + ## å‚č€ƒé“¾ęŽ„ - LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf diff --git a/ppstructure/vqa/eval_re.py b/ppstructure/vqa/eval_re.py deleted file mode 100644 index 68c27bad8a..0000000000 --- a/ppstructure/vqa/eval_re.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) - -import paddle - -from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction - -from xfun import XFUNDataset -from vqa_utils import parse_args, get_bio_label_maps, print_arguments -from data_collator import DataCollator -from metric import re_score - -from ppocr.utils.logging import get_logger - - -def cal_metric(re_preds, re_labels, entities): - gt_relations = [] - for b in range(len(re_labels)): - rel_sent = [] - for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]): - rel = {} - rel["head_id"] = head - rel["head"] = (entities[b]["start"][rel["head_id"]], - entities[b]["end"][rel["head_id"]]) - rel["head_type"] = entities[b]["label"][rel["head_id"]] - - rel["tail_id"] = tail - rel["tail"] = (entities[b]["start"][rel["tail_id"]], - entities[b]["end"][rel["tail_id"]]) - rel["tail_type"] = entities[b]["label"][rel["tail_id"]] - - rel["type"] = 1 - rel_sent.append(rel) - gt_relations.append(rel_sent) - re_metrics = re_score(re_preds, gt_relations, mode="boundaries") - return re_metrics - - -def evaluate(model, eval_dataloader, logger, prefix=""): - # Eval! 
- logger.info("***** Running evaluation {} *****".format(prefix)) - logger.info(" Num examples = {}".format(len(eval_dataloader.dataset))) - - re_preds = [] - re_labels = [] - entities = [] - eval_loss = 0.0 - model.eval() - for idx, batch in enumerate(eval_dataloader): - with paddle.no_grad(): - outputs = model(**batch) - loss = outputs['loss'].mean().item() - if paddle.distributed.get_rank() == 0: - logger.info("[Eval] process: {}/{}, loss: {:.5f}".format( - idx, len(eval_dataloader), loss)) - - eval_loss += loss - re_preds.extend(outputs['pred_relations']) - re_labels.extend(batch['relations']) - entities.extend(batch['entities']) - re_metrics = cal_metric(re_preds, re_labels, entities) - re_metrics = { - "precision": re_metrics["ALL"]["p"], - "recall": re_metrics["ALL"]["r"], - "f1": re_metrics["ALL"]["f1"], - } - model.train() - return re_metrics - - -def eval(args): - logger = get_logger() - label2id_map, id2label_map = get_bio_label_maps(args.label_map_path) - pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index - - tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path) - - model = LayoutXLMForRelationExtraction.from_pretrained( - args.model_name_or_path) - - eval_dataset = XFUNDataset( - tokenizer, - data_dir=args.eval_data_dir, - label_path=args.eval_label_path, - label2id_map=label2id_map, - img_size=(224, 224), - max_seq_len=args.max_seq_length, - pad_token_label_id=pad_token_label_id, - contains_re=True, - add_special_ids=False, - return_attention_mask=True, - load_mode='all') - - eval_dataloader = paddle.io.DataLoader( - eval_dataset, - batch_size=args.per_gpu_eval_batch_size, - num_workers=args.num_workers, - shuffle=False, - collate_fn=DataCollator()) - - results = evaluate(model, eval_dataloader, logger) - logger.info("eval results: {}".format(results)) - - -if __name__ == "__main__": - args = parse_args() - eval(args) diff --git a/ppstructure/vqa/eval_ser.py b/ppstructure/vqa/eval_ser.py deleted file mode 100644 index 95f428c721..0000000000 --- a/ppstructure/vqa/eval_ser.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) - -import random -import time -import copy -import logging - -import argparse -import paddle -import numpy as np -from seqeval.metrics import classification_report, f1_score, precision_score, recall_score -from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification -from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification - -from xfun import XFUNDataset -from losses import SERLoss -from vqa_utils import parse_args, get_bio_label_maps, print_arguments - -from ppocr.utils.logging import get_logger - -MODELS = { - 'LayoutXLM': - (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification), - 'LayoutLM': - (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification) -} - - -def eval(args): - logger = get_logger() - print_arguments(args, logger) - - label2id_map, id2label_map = get_bio_label_maps(args.label_map_path) - pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index - - tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type] - tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) - model = model_class.from_pretrained(args.model_name_or_path) - - eval_dataset = XFUNDataset( - tokenizer, - data_dir=args.eval_data_dir, - label_path=args.eval_label_path, - label2id_map=label2id_map, - img_size=(224, 224), - pad_token_label_id=pad_token_label_id, - contains_re=False, - add_special_ids=False, - return_attention_mask=True, - load_mode='all') - - eval_dataloader = paddle.io.DataLoader( - eval_dataset, - batch_size=args.per_gpu_eval_batch_size, - num_workers=args.num_workers, - use_shared_memory=True, - collate_fn=None, ) - - loss_class = SERLoss(len(label2id_map)) - - results, _ = evaluate(args, model, tokenizer, loss_class, eval_dataloader, - label2id_map, id2label_map, pad_token_label_id, - logger) - - logger.info(results) - - -def evaluate(args, - model, - tokenizer, - loss_class, - eval_dataloader, - label2id_map, - id2label_map, - pad_token_label_id, - logger, - prefix=""): - - eval_loss = 0.0 - nb_eval_steps = 0 - preds = None - out_label_ids = None - model.eval() - for idx, batch in enumerate(eval_dataloader): - with paddle.no_grad(): - if args.ser_model_type == 'LayoutLM': - if 'image' in batch: - batch.pop('image') - labels = batch.pop('labels') - outputs = model(**batch) - if args.ser_model_type == 'LayoutXLM': - outputs = outputs[0] - loss = loss_class(labels, outputs, batch['attention_mask']) - - loss = loss.mean() - - if paddle.distributed.get_rank() == 0: - logger.info("[Eval]process: {}/{}, loss: {:.5f}".format( - idx, len(eval_dataloader), loss.numpy()[0])) - - eval_loss += loss.item() - nb_eval_steps += 1 - if preds is None: - preds = outputs.numpy() - out_label_ids = labels.numpy() - else: - preds = np.append(preds, outputs.numpy(), axis=0) - out_label_ids = np.append(out_label_ids, labels.numpy(), axis=0) - - eval_loss = eval_loss / nb_eval_steps - preds = np.argmax(preds, axis=2) - - # label_map = {i: label.upper() for i, label in enumerate(labels)} - - out_label_list = [[] for _ in range(out_label_ids.shape[0])] - preds_list = [[] for _ in range(out_label_ids.shape[0])] - - for i in range(out_label_ids.shape[0]): - for j in range(out_label_ids.shape[1]): - if out_label_ids[i, j] != pad_token_label_id: - out_label_list[i].append(id2label_map[out_label_ids[i][j]]) 
- preds_list[i].append(id2label_map[preds[i][j]]) - - results = { - "loss": eval_loss, - "precision": precision_score(out_label_list, preds_list), - "recall": recall_score(out_label_list, preds_list), - "f1": f1_score(out_label_list, preds_list), - } - - with open( - os.path.join(args.output_dir, "test_gt.txt"), "w", - encoding='utf-8') as fout: - for lbl in out_label_list: - for l in lbl: - fout.write(l + "\t") - fout.write("\n") - with open( - os.path.join(args.output_dir, "test_pred.txt"), "w", - encoding='utf-8') as fout: - for lbl in preds_list: - for l in lbl: - fout.write(l + "\t") - fout.write("\n") - - report = classification_report(out_label_list, preds_list) - logger.info("\n" + report) - - logger.info("***** Eval results %s *****", prefix) - for key in sorted(results.keys()): - logger.info(" %s = %s", key, str(results[key])) - model.train() - return results, preds_list - - -if __name__ == "__main__": - args = parse_args() - eval(args) diff --git a/ppstructure/vqa/helper/trans_xfun_data.py b/ppstructure/vqa/helper/trans_xfun_data.py index 25b3963d83..93ec98163c 100644 --- a/ppstructure/vqa/helper/trans_xfun_data.py +++ b/ppstructure/vqa/helper/trans_xfun_data.py @@ -49,4 +49,16 @@ def transfer_xfun_data(json_path=None, output_file=None): print("===ok====") -transfer_xfun_data("./xfun/zh.val.json", "./xfun_normalize_val.json") +def parser_args(): + import argparse + parser = argparse.ArgumentParser(description="args for paddleserving") + parser.add_argument( + "--ori_gt_path", type=str, required=True, help='origin xfun gt path') + parser.add_argument( + "--output_path", type=str, required=True, help='path to save') + args = parser.parse_args() + return args + + +args = parser_args() +transfer_xfun_data(args.ori_gt_path, args.output_path) diff --git a/ppstructure/vqa/infer.sh b/ppstructure/vqa/infer.sh deleted file mode 100644 index 2cd1cea447..0000000000 --- a/ppstructure/vqa/infer.sh +++ /dev/null @@ -1,61 +0,0 @@ -export CUDA_VISIBLE_DEVICES=6 -# python3.7 infer_ser_e2e.py \ -# --model_name_or_path "output/ser_distributed/best_model" \ -# --max_seq_length 512 \ -# --output_dir "output_res_e2e/" \ -# --infer_imgs "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val/zh_val_0.jpg" - - -# python3.7 infer_ser_re_e2e.py \ -# --model_name_or_path "output/ser_distributed/best_model" \ -# --re_model_name_or_path "output/re_test/best_model" \ -# --max_seq_length 512 \ -# --output_dir "output_ser_re_e2e_train/" \ -# --infer_imgs "images/input/zh_val_21.jpg" - -# python3.7 infer_ser.py \ -# --model_name_or_path "output/ser_LayoutLM/best_model" \ -# --ser_model_type "LayoutLM" \ -# --output_dir "ser_LayoutLM/" \ -# --infer_imgs "images/input/zh_val_21.jpg" \ -# --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" - -python3.7 infer_ser.py \ - --model_name_or_path "output/ser_new/best_model" \ - --ser_model_type "LayoutXLM" \ - --output_dir "ser_new/" \ - --infer_imgs "images/input/zh_val_21.jpg" \ - --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" - -# python3.7 infer_ser_e2e.py \ -# --model_name_or_path "output/ser_new/best_model" \ -# --ser_model_type "LayoutXLM" \ -# --max_seq_length 512 \ -# --output_dir "output/ser_new/" \ -# --infer_imgs "images/input/zh_val_0.jpg" - - -# python3.7 infer_ser_e2e.py \ -# --model_name_or_path "output/ser_LayoutLM/best_model" \ -# --ser_model_type "LayoutLM" \ -# --max_seq_length 512 \ -# --output_dir "output/ser_LayoutLM/" \ -# --infer_imgs "images/input/zh_val_0.jpg" - -# python3 infer_re.py \ -# 
--model_name_or_path "/ssd1/zhoujun20/VQA/PaddleOCR/ppstructure/vqa/output/re_test/best_model/" \ -# --max_seq_length 512 \ -# --eval_data_dir "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val" \ -# --eval_label_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" \ -# --label_map_path 'labels/labels_ser.txt' \ -# --output_dir "output_res" \ -# --per_gpu_eval_batch_size 1 \ -# --seed 2048 - -# python3.7 infer_ser_re_e2e.py \ -# --model_name_or_path "output/ser_LayoutLM/best_model" \ -# --ser_model_type "LayoutLM" \ -# --re_model_name_or_path "output/re_new/best_model" \ -# --max_seq_length 512 \ -# --output_dir "output_ser_re_e2e/" \ -# --infer_imgs "images/input/zh_val_21.jpg" \ No newline at end of file diff --git a/ppstructure/vqa/infer_re.py b/ppstructure/vqa/infer_re.py deleted file mode 100644 index b6774e77be..0000000000 --- a/ppstructure/vqa/infer_re.py +++ /dev/null @@ -1,165 +0,0 @@ -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) - -import random - -import cv2 -import matplotlib.pyplot as plt -import numpy as np -import paddle - -from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction - -from xfun import XFUNDataset -from vqa_utils import parse_args, get_bio_label_maps, draw_re_results -from data_collator import DataCollator - -from ppocr.utils.logging import get_logger - - -def infer(args): - os.makedirs(args.output_dir, exist_ok=True) - logger = get_logger() - label2id_map, id2label_map = get_bio_label_maps(args.label_map_path) - pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index - - tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path) - - model = LayoutXLMForRelationExtraction.from_pretrained( - args.model_name_or_path) - - eval_dataset = XFUNDataset( - tokenizer, - data_dir=args.eval_data_dir, - label_path=args.eval_label_path, - label2id_map=label2id_map, - img_size=(224, 224), - max_seq_len=args.max_seq_length, - pad_token_label_id=pad_token_label_id, - contains_re=True, - add_special_ids=False, - return_attention_mask=True, - load_mode='all') - - eval_dataloader = paddle.io.DataLoader( - eval_dataset, - batch_size=args.per_gpu_eval_batch_size, - num_workers=8, - shuffle=False, - collate_fn=DataCollator()) - - # čÆ»å–gtēš„octę•°ę® - ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path) - - for idx, batch in enumerate(eval_dataloader): - ocr_info = ocr_info_list[idx] - image_path = ocr_info['image_path'] - ocr_info = ocr_info['ocr_info'] - - save_img_path = os.path.join( - args.output_dir, - os.path.splitext(os.path.basename(image_path))[0] + "_re.jpg") - logger.info("[Infer] process: {}/{}, save result to {}".format( - idx, len(eval_dataloader), save_img_path)) - with paddle.no_grad(): - outputs = model(**batch) - pred_relations = outputs['pred_relations'] - - # ę ¹ę®entityé‡Œēš„äæ”ęÆļ¼Œåštokenč§£ē åŽåŽ»čæ‡ę»¤äøč¦ēš„ocr_info - ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer) - - # čæ›č”Œ relations 到 ocräæ”ęÆēš„č½¬ę¢ - result = [] - used_tail_id = [] - for relations in pred_relations: - for relation in relations: - if relation['tail_id'] in used_tail_id: - continue - if relation['head_id'] not in ocr_info or relation[ - 'tail_id'] not in ocr_info: - continue - used_tail_id.append(relation['tail_id']) - ocr_info_head = ocr_info[relation['head_id']] - ocr_info_tail = ocr_info[relation['tail_id']] - result.append((ocr_info_head, 
ocr_info_tail)) - - img = cv2.imread(image_path) - img_show = draw_re_results(img, result) - cv2.imwrite(save_img_path, img_show) - - -def load_ocr(img_folder, json_path): - import json - d = [] - with open(json_path, "r", encoding='utf-8') as fin: - lines = fin.readlines() - for line in lines: - image_name, info_str = line.split("\t") - info_dict = json.loads(info_str) - info_dict['image_path'] = os.path.join(img_folder, image_name) - d.append(info_dict) - return d - - -def filter_bg_by_txt(ocr_info, batch, tokenizer): - entities = batch['entities'][0] - input_ids = batch['input_ids'][0] - - new_info_dict = {} - for i in range(len(entities['start'])): - entitie_head = entities['start'][i] - entitie_tail = entities['end'][i] - word_input_ids = input_ids[entitie_head:entitie_tail].numpy().tolist() - txt = tokenizer.convert_ids_to_tokens(word_input_ids) - txt = tokenizer.convert_tokens_to_string(txt) - - for i, info in enumerate(ocr_info): - if info['text'] == txt: - new_info_dict[i] = info - return new_info_dict - - -def post_process(pred_relations, ocr_info, img): - result = [] - for relations in pred_relations: - for relation in relations: - ocr_info_head = ocr_info[relation['head_id']] - ocr_info_tail = ocr_info[relation['tail_id']] - result.append((ocr_info_head, ocr_info_tail)) - return result - - -def draw_re(result, image_path, output_folder): - img = cv2.imread(image_path) - - from matplotlib import pyplot as plt - for ocr_info_head, ocr_info_tail in result: - cv2.rectangle( - img, - tuple(ocr_info_head['bbox'][:2]), - tuple(ocr_info_head['bbox'][2:]), (255, 0, 0), - thickness=2) - cv2.rectangle( - img, - tuple(ocr_info_tail['bbox'][:2]), - tuple(ocr_info_tail['bbox'][2:]), (0, 0, 255), - thickness=2) - center_p1 = [(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2, - (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2] - center_p2 = [(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2, - (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2] - cv2.line( - img, tuple(center_p1), tuple(center_p2), (0, 255, 0), thickness=2) - plt.imshow(img) - plt.savefig( - os.path.join(output_folder, os.path.basename(image_path)), dpi=600) - # plt.show() - - -if __name__ == "__main__": - args = parse_args() - infer(args) diff --git a/ppstructure/vqa/infer_ser.py b/ppstructure/vqa/infer_ser.py deleted file mode 100644 index f5fb581fa7..0000000000 --- a/ppstructure/vqa/infer_ser.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) - -import json -import cv2 -import numpy as np -from copy import deepcopy - -import paddle - -# relative reference -from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps -from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification -from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification - -MODELS = { - 'LayoutXLM': - (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification), - 'LayoutLM': - (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification) -} - - -def pad_sentences(tokenizer, - encoded_inputs, - max_seq_len=512, - pad_to_max_seq_len=True, - return_attention_mask=True, - return_token_type_ids=True, - return_overflowing_tokens=False, - return_special_tokens_mask=False): - # Padding with larger size, reshape is carried out - max_seq_len = ( - len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len - - needs_to_be_padded = pad_to_max_seq_len and \ - max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len - - if needs_to_be_padded: - difference = max_seq_len - len(encoded_inputs["input_ids"]) - if tokenizer.padding_side == 'right': - if return_attention_mask: - encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[ - "input_ids"]) + [0] * difference - if return_token_type_ids: - encoded_inputs["token_type_ids"] = ( - encoded_inputs["token_type_ids"] + - [tokenizer.pad_token_type_id] * difference) - if return_special_tokens_mask: - encoded_inputs["special_tokens_mask"] = encoded_inputs[ - "special_tokens_mask"] + [1] * difference - encoded_inputs["input_ids"] = encoded_inputs[ - "input_ids"] + [tokenizer.pad_token_id] * difference - encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0] - ] * difference - else: - assert False, "padding_side of tokenizer just supports [\"right\"] but got {}".format( - tokenizer.padding_side) - else: - if return_attention_mask: - encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[ - "input_ids"]) - - return encoded_inputs - - -def split_page(encoded_inputs, max_seq_len=512): - """ - truncate is often used in training process - """ - for key in encoded_inputs: - encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key]) - if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on - encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len]) - else: # for bbox - encoded_inputs[key] = encoded_inputs[key].reshape( - [-1, max_seq_len, 4]) - return encoded_inputs - - -def preprocess( - tokenizer, - ori_img, - ocr_info, - img_size=(224, 224), - pad_token_label_id=-100, - max_seq_len=512, - add_special_ids=False, - return_attention_mask=True, ): - ocr_info = deepcopy(ocr_info) - height = ori_img.shape[0] - width = ori_img.shape[1] - - img = cv2.resize(ori_img, - (224, 224)).transpose([2, 0, 1]).astype(np.float32) - - segment_offset_id = [] - words_list = [] - bbox_list = [] - input_ids_list = [] - token_type_ids_list = [] - - for info in ocr_info: - # x1, y1, x2, y2 - bbox = info["bbox"] - bbox[0] = int(bbox[0] * 1000.0 / width) - bbox[2] = int(bbox[2] * 1000.0 / width) - bbox[1] = int(bbox[1] * 1000.0 / height) - bbox[3] = int(bbox[3] * 1000.0 / height) - - text = info["text"] - encode_res = tokenizer.encode( - text, pad_to_max_seq_len=False, return_attention_mask=True) - - if not add_special_ids: - # TODO: use tok.all_special_ids to remove - 
encode_res["input_ids"] = encode_res["input_ids"][1:-1] - encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1] - encode_res["attention_mask"] = encode_res["attention_mask"][1:-1] - - input_ids_list.extend(encode_res["input_ids"]) - token_type_ids_list.extend(encode_res["token_type_ids"]) - bbox_list.extend([bbox] * len(encode_res["input_ids"])) - words_list.append(text) - segment_offset_id.append(len(input_ids_list)) - - encoded_inputs = { - "input_ids": input_ids_list, - "token_type_ids": token_type_ids_list, - "bbox": bbox_list, - "attention_mask": [1] * len(input_ids_list), - } - - encoded_inputs = pad_sentences( - tokenizer, - encoded_inputs, - max_seq_len=max_seq_len, - return_attention_mask=return_attention_mask) - - encoded_inputs = split_page(encoded_inputs) - - fake_bs = encoded_inputs["input_ids"].shape[0] - - encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand( - [fake_bs] + list(img.shape)) - - encoded_inputs["segment_offset_id"] = segment_offset_id - - return encoded_inputs - - -def postprocess(attention_mask, preds, label_map_path): - if isinstance(preds, paddle.Tensor): - preds = preds.numpy() - preds = np.argmax(preds, axis=2) - - _, label_map = get_bio_label_maps(label_map_path) - - preds_list = [[] for _ in range(preds.shape[0])] - - # keep batch info - for i in range(preds.shape[0]): - for j in range(preds.shape[1]): - if attention_mask[i][j] == 1: - preds_list[i].append(label_map[preds[i][j]]) - - return preds_list - - -def merge_preds_list_with_ocr_info(label_map_path, ocr_info, segment_offset_id, - preds_list): - # must ensure the preds_list is generated from the same image - preds = [p for pred in preds_list for p in pred] - label2id_map, _ = get_bio_label_maps(label_map_path) - for key in label2id_map: - if key.startswith("I-"): - label2id_map[key] = label2id_map["B" + key[1:]] - - id2label_map = dict() - for key in label2id_map: - val = label2id_map[key] - if key == "O": - id2label_map[val] = key - if key.startswith("B-") or key.startswith("I-"): - id2label_map[val] = key[2:] - else: - id2label_map[val] = key - - for idx in range(len(segment_offset_id)): - if idx == 0: - start_id = 0 - else: - start_id = segment_offset_id[idx - 1] - - end_id = segment_offset_id[idx] - - curr_pred = preds[start_id:end_id] - curr_pred = [label2id_map[p] for p in curr_pred] - - if len(curr_pred) <= 0: - pred_id = 0 - else: - counts = np.bincount(curr_pred) - pred_id = np.argmax(counts) - ocr_info[idx]["pred_id"] = int(pred_id) - ocr_info[idx]["pred"] = id2label_map[pred_id] - return ocr_info - - -@paddle.no_grad() -def infer(args): - os.makedirs(args.output_dir, exist_ok=True) - - # init token and model - tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type] - tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) - model = model_class.from_pretrained(args.model_name_or_path) - - model.eval() - - # load ocr results json - ocr_results = dict() - with open(args.ocr_json_path, "r", encoding='utf-8') as fin: - lines = fin.readlines() - for line in lines: - img_name, json_info = line.split("\t") - ocr_results[os.path.basename(img_name)] = json.loads(json_info) - - # get infer img list - infer_imgs = get_image_file_list(args.infer_imgs) - - # loop for infer - with open( - os.path.join(args.output_dir, "infer_results.txt"), - "w", - encoding='utf-8') as fout: - for idx, img_path in enumerate(infer_imgs): - save_img_path = os.path.join(args.output_dir, - os.path.basename(img_path)) - print("process: [{}/{}], save result to 
{}".format( - idx, len(infer_imgs), save_img_path)) - - img = cv2.imread(img_path) - - ocr_info = ocr_results[os.path.basename(img_path)]["ocr_info"] - inputs = preprocess( - tokenizer=tokenizer, - ori_img=img, - ocr_info=ocr_info, - max_seq_len=args.max_seq_length) - if args.ser_model_type == 'LayoutLM': - preds = model( - input_ids=inputs["input_ids"], - bbox=inputs["bbox"], - token_type_ids=inputs["token_type_ids"], - attention_mask=inputs["attention_mask"]) - elif args.ser_model_type == 'LayoutXLM': - preds = model( - input_ids=inputs["input_ids"], - bbox=inputs["bbox"], - image=inputs["image"], - token_type_ids=inputs["token_type_ids"], - attention_mask=inputs["attention_mask"]) - preds = preds[0] - - preds = postprocess(inputs["attention_mask"], preds, - args.label_map_path) - ocr_info = merge_preds_list_with_ocr_info( - args.label_map_path, ocr_info, inputs["segment_offset_id"], - preds) - - fout.write(img_path + "\t" + json.dumps( - { - "ocr_info": ocr_info, - }, ensure_ascii=False) + "\n") - - img_res = draw_ser_results(img, ocr_info) - cv2.imwrite(save_img_path, img_res) - - return - - -if __name__ == "__main__": - args = parse_args() - infer(args) diff --git a/ppstructure/vqa/infer_ser_e2e.py b/ppstructure/vqa/infer_ser_e2e.py deleted file mode 100644 index 33fe4dbb5e..0000000000 --- a/ppstructure/vqa/infer_ser_e2e.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) - -import json -import cv2 -import numpy as np -from copy import deepcopy -from PIL import Image - -import paddle -from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification -from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification - -# relative reference -from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps - -from vqa_utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info - -MODELS = { - 'LayoutXLM': - (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification), - 'LayoutLM': - (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification) -} - - -def trans_poly_to_bbox(poly): - x1 = np.min([p[0] for p in poly]) - x2 = np.max([p[0] for p in poly]) - y1 = np.min([p[1] for p in poly]) - y2 = np.max([p[1] for p in poly]) - return [x1, y1, x2, y2] - - -def parse_ocr_info_for_ser(ocr_result): - ocr_info = [] - for res in ocr_result: - ocr_info.append({ - "text": res[1][0], - "bbox": trans_poly_to_bbox(res[0]), - "poly": res[0], - }) - return ocr_info - - -class SerPredictor(object): - def __init__(self, args): - self.args = args - self.max_seq_length = args.max_seq_length - - # init ser token and model - tokenizer_class, base_model_class, model_class = MODELS[ - args.ser_model_type] - self.tokenizer = tokenizer_class.from_pretrained( - args.model_name_or_path) - self.model = model_class.from_pretrained(args.model_name_or_path) - self.model.eval() - - # init ocr_engine - from paddleocr import PaddleOCR - - self.ocr_engine = PaddleOCR( - rec_model_dir=args.rec_model_dir, - det_model_dir=args.det_model_dir, - use_angle_cls=False, - show_log=False) - # init dict - label2id_map, self.id2label_map = get_bio_label_maps( - args.label_map_path) - self.label2id_map_for_draw = dict() - for key in label2id_map: - if key.startswith("I-"): - self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]] - else: - self.label2id_map_for_draw[key] = label2id_map[key] - - def __call__(self, img): - ocr_result = self.ocr_engine.ocr(img, cls=False) - - ocr_info = parse_ocr_info_for_ser(ocr_result) - - inputs = preprocess( - tokenizer=self.tokenizer, - ori_img=img, - ocr_info=ocr_info, - max_seq_len=self.max_seq_length) - - if self.args.ser_model_type == 'LayoutLM': - preds = self.model( - input_ids=inputs["input_ids"], - bbox=inputs["bbox"], - token_type_ids=inputs["token_type_ids"], - attention_mask=inputs["attention_mask"]) - elif self.args.ser_model_type == 'LayoutXLM': - preds = self.model( - input_ids=inputs["input_ids"], - bbox=inputs["bbox"], - image=inputs["image"], - token_type_ids=inputs["token_type_ids"], - attention_mask=inputs["attention_mask"]) - preds = preds[0] - - preds = postprocess(inputs["attention_mask"], preds, self.id2label_map) - ocr_info = merge_preds_list_with_ocr_info( - ocr_info, inputs["segment_offset_id"], preds, - self.label2id_map_for_draw) - return ocr_info, inputs - - -if __name__ == "__main__": - args = parse_args() - os.makedirs(args.output_dir, exist_ok=True) - - # get infer img list - infer_imgs = get_image_file_list(args.infer_imgs) - - # loop for infer - ser_engine = SerPredictor(args) - with open( - os.path.join(args.output_dir, "infer_results.txt"), - "w", - encoding='utf-8') as fout: - for idx, img_path in enumerate(infer_imgs): - save_img_path = os.path.join( - args.output_dir, - 
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg") - print("process: [{}/{}], save result to {}".format( - idx, len(infer_imgs), save_img_path)) - - img = cv2.imread(img_path) - - result, _ = ser_engine(img) - fout.write(img_path + "\t" + json.dumps( - { - "ser_resule": result, - }, ensure_ascii=False) + "\n") - - img_res = draw_ser_results(img, result) - cv2.imwrite(save_img_path, img_res) diff --git a/ppstructure/vqa/infer_ser_re_e2e.py b/ppstructure/vqa/infer_ser_re_e2e.py deleted file mode 100644 index e24c9f69e0..0000000000 --- a/ppstructure/vqa/infer_ser_re_e2e.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys -import json -import cv2 -import numpy as np -from copy import deepcopy -from PIL import Image - -import paddle -from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction - -# relative reference -from vqa_utils import parse_args, get_image_file_list, draw_re_results -from infer_ser_e2e import SerPredictor - - -def make_input(ser_input, ser_result, max_seq_len=512): - entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2} - - entities = ser_input['entities'][0] - assert len(entities) == len(ser_result) - - # entities - start = [] - end = [] - label = [] - entity_idx_dict = {} - for i, (res, entity) in enumerate(zip(ser_result, entities)): - if res['pred'] == 'O': - continue - entity_idx_dict[len(start)] = i - start.append(entity['start']) - end.append(entity['end']) - label.append(entities_labels[res['pred']]) - entities = dict(start=start, end=end, label=label) - - # relations - head = [] - tail = [] - for i in range(len(entities["label"])): - for j in range(len(entities["label"])): - if entities["label"][i] == 1 and entities["label"][j] == 2: - head.append(i) - tail.append(j) - - relations = dict(head=head, tail=tail) - - batch_size = ser_input["input_ids"].shape[0] - entities_batch = [] - relations_batch = [] - for b in range(batch_size): - entities_batch.append(entities) - relations_batch.append(relations) - - ser_input['entities'] = entities_batch - ser_input['relations'] = relations_batch - - ser_input.pop('segment_offset_id') - return ser_input, entity_idx_dict - - -class SerReSystem(object): - def __init__(self, args): - self.ser_engine = SerPredictor(args) - self.tokenizer = LayoutXLMTokenizer.from_pretrained( - args.re_model_name_or_path) - self.model = LayoutXLMForRelationExtraction.from_pretrained( - args.re_model_name_or_path) - self.model.eval() - - def __call__(self, img): - ser_result, ser_inputs = self.ser_engine(img) - re_input, entity_idx_dict = make_input(ser_inputs, ser_result) - - re_result = self.model(**re_input) - - pred_relations = re_result['pred_relations'][0] - # čæ›č”Œ relations 到 ocräæ”ęÆēš„č½¬ę¢ - result = [] - used_tail_id = [] - for relation in pred_relations: - if relation['tail_id'] in used_tail_id: - continue - used_tail_id.append(relation['tail_id']) - 
ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]] - ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]] - result.append((ocr_info_head, ocr_info_tail)) - - return result - - -if __name__ == "__main__": - args = parse_args() - os.makedirs(args.output_dir, exist_ok=True) - - # get infer img list - infer_imgs = get_image_file_list(args.infer_imgs) - - # loop for infer - ser_re_engine = SerReSystem(args) - with open( - os.path.join(args.output_dir, "infer_results.txt"), - "w", - encoding='utf-8') as fout: - for idx, img_path in enumerate(infer_imgs): - save_img_path = os.path.join( - args.output_dir, - os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg") - print("process: [{}/{}], save result to {}".format( - idx, len(infer_imgs), save_img_path)) - - img = cv2.imread(img_path) - - result = ser_re_engine(img) - fout.write(img_path + "\t" + json.dumps( - { - "result": result, - }, ensure_ascii=False) + "\n") - - img_res = draw_re_results(img, result) - cv2.imwrite(save_img_path, img_res) diff --git a/ppstructure/vqa/metric.py b/ppstructure/vqa/metric.py deleted file mode 100644 index cb58370521..0000000000 --- a/ppstructure/vqa/metric.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import re - -import numpy as np - -import logging - -logger = logging.getLogger(__name__) - -PREFIX_CHECKPOINT_DIR = "checkpoint" -_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$") - - -def get_last_checkpoint(folder): - content = os.listdir(folder) - checkpoints = [ - path for path in content - if _re_checkpoint.search(path) is not None and os.path.isdir( - os.path.join(folder, path)) - ] - if len(checkpoints) == 0: - return - return os.path.join( - folder, - max(checkpoints, - key=lambda x: int(_re_checkpoint.search(x).groups()[0]))) - - -def re_score(pred_relations, gt_relations, mode="strict"): - """Evaluate RE predictions - - Args: - pred_relations (list) : list of list of predicted relations (several relations in each sentence) - gt_relations (list) : list of list of ground truth relations - - rel = { "head": (start_idx (inclusive), end_idx (exclusive)), - "tail": (start_idx (inclusive), end_idx (exclusive)), - "head_type": ent_type, - "tail_type": ent_type, - "type": rel_type} - - vocab (Vocab) : dataset vocabulary - mode (str) : in 'strict' or 'boundaries'""" - - assert mode in ["strict", "boundaries"] - - relation_types = [v for v in [0, 1] if not v == 0] - scores = { - rel: { - "tp": 0, - "fp": 0, - "fn": 0 - } - for rel in relation_types + ["ALL"] - } - - # Count GT relations and Predicted relations - n_sents = len(gt_relations) - n_rels = sum([len([rel for rel in sent]) for sent in gt_relations]) - n_found = sum([len([rel for rel in sent]) for sent in pred_relations]) - - # Count TP, FP and FN per type - for pred_sent, gt_sent in zip(pred_relations, gt_relations): - for rel_type in relation_types: - # strict mode takes argument types into account - if mode == "strict": - pred_rels = {(rel["head"], rel["head_type"], rel["tail"], - rel["tail_type"]) - for rel in pred_sent if rel["type"] == rel_type} - gt_rels = {(rel["head"], rel["head_type"], rel["tail"], - rel["tail_type"]) - for rel in gt_sent if rel["type"] == rel_type} - - # boundaries mode only takes argument spans into account - elif mode == "boundaries": - pred_rels = {(rel["head"], rel["tail"]) - for rel in pred_sent if rel["type"] == rel_type} - gt_rels = {(rel["head"], rel["tail"]) - for rel in gt_sent if rel["type"] == rel_type} - - scores[rel_type]["tp"] += len(pred_rels & gt_rels) - scores[rel_type]["fp"] += len(pred_rels - gt_rels) - scores[rel_type]["fn"] += len(gt_rels - pred_rels) - - # Compute per entity Precision / Recall / F1 - for rel_type in scores.keys(): - if scores[rel_type]["tp"]: - scores[rel_type]["p"] = scores[rel_type]["tp"] / ( - scores[rel_type]["fp"] + scores[rel_type]["tp"]) - scores[rel_type]["r"] = scores[rel_type]["tp"] / ( - scores[rel_type]["fn"] + scores[rel_type]["tp"]) - else: - scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0 - - if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0: - scores[rel_type]["f1"] = ( - 2 * scores[rel_type]["p"] * scores[rel_type]["r"] / - (scores[rel_type]["p"] + scores[rel_type]["r"])) - else: - scores[rel_type]["f1"] = 0 - - # Compute micro F1 Scores - tp = sum([scores[rel_type]["tp"] for rel_type in relation_types]) - fp = sum([scores[rel_type]["fp"] for rel_type in relation_types]) - fn = sum([scores[rel_type]["fn"] for rel_type in relation_types]) - - if tp: - precision = tp / (tp + fp) - recall = tp / (tp + fn) - f1 = 2 * precision * recall / (precision + recall) - - else: - precision, recall, f1 = 0, 0, 0 - - scores["ALL"]["p"] = precision - scores["ALL"]["r"] = recall - scores["ALL"]["f1"] = f1 - 
scores["ALL"]["tp"] = tp - scores["ALL"]["fp"] = fp - scores["ALL"]["fn"] = fn - - # Compute Macro F1 Scores - scores["ALL"]["Macro_f1"] = np.mean( - [scores[ent_type]["f1"] for ent_type in relation_types]) - scores["ALL"]["Macro_p"] = np.mean( - [scores[ent_type]["p"] for ent_type in relation_types]) - scores["ALL"]["Macro_r"] = np.mean( - [scores[ent_type]["r"] for ent_type in relation_types]) - - # logger.info(f"RE Evaluation in *** {mode.upper()} *** mode") - - # logger.info( - # "processed {} sentences with {} relations; found: {} relations; correct: {}.".format( - # n_sents, n_rels, n_found, tp - # ) - # ) - # logger.info( - # "\tALL\t TP: {};\tFP: {};\tFN: {}".format(scores["ALL"]["tp"], scores["ALL"]["fp"], scores["ALL"]["fn"]) - # ) - # logger.info("\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(precision, recall, f1)) - # logger.info( - # "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format( - # scores["ALL"]["Macro_p"], scores["ALL"]["Macro_r"], scores["ALL"]["Macro_f1"] - # ) - # ) - - # for rel_type in relation_types: - # logger.info( - # "\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format( - # rel_type, - # scores[rel_type]["tp"], - # scores[rel_type]["fp"], - # scores[rel_type]["fn"], - # scores[rel_type]["p"], - # scores[rel_type]["r"], - # scores[rel_type]["f1"], - # scores[rel_type]["tp"] + scores[rel_type]["fp"], - # ) - # ) - - return scores diff --git a/ppstructure/vqa/train_re.py b/ppstructure/vqa/train_re.py deleted file mode 100644 index eeff2bfbbe..0000000000 --- a/ppstructure/vqa/train_re.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) - -import random -import time -import numpy as np -import paddle - -from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction - -from xfun import XFUNDataset -from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed -from data_collator import DataCollator -from eval_re import evaluate - -from ppocr.utils.logging import get_logger - - -def train(args): - logger = get_logger(log_file=os.path.join(args.output_dir, "train.log")) - rank = paddle.distributed.get_rank() - distributed = paddle.distributed.get_world_size() > 1 - - print_arguments(args, logger) - - # Added here for reproducibility (even between python 2 and 3) - set_seed(args.seed) - - label2id_map, id2label_map = get_bio_label_maps(args.label_map_path) - pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index - - # dist mode - if distributed: - paddle.distributed.init_parallel_env() - - tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path) - if not args.resume: - model = LayoutXLMModel.from_pretrained(args.model_name_or_path) - model = LayoutXLMForRelationExtraction(model, dropout=None) - logger.info('train from scratch') - else: - logger.info('resume from {}'.format(args.model_name_or_path)) - model = LayoutXLMForRelationExtraction.from_pretrained( - args.model_name_or_path) - - # dist mode - if distributed: - model = paddle.DataParallel(model) - - train_dataset = XFUNDataset( - tokenizer, - data_dir=args.train_data_dir, - label_path=args.train_label_path, - label2id_map=label2id_map, - img_size=(224, 224), - max_seq_len=args.max_seq_length, - pad_token_label_id=pad_token_label_id, - contains_re=True, - add_special_ids=False, - return_attention_mask=True, - load_mode='all') - - eval_dataset = XFUNDataset( - tokenizer, - data_dir=args.eval_data_dir, - label_path=args.eval_label_path, - label2id_map=label2id_map, - img_size=(224, 224), - max_seq_len=args.max_seq_length, - pad_token_label_id=pad_token_label_id, - contains_re=True, - add_special_ids=False, - return_attention_mask=True, - load_mode='all') - - train_sampler = paddle.io.DistributedBatchSampler( - train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True) - - train_dataloader = paddle.io.DataLoader( - train_dataset, - batch_sampler=train_sampler, - num_workers=args.num_workers, - use_shared_memory=True, - collate_fn=DataCollator()) - - eval_dataloader = paddle.io.DataLoader( - eval_dataset, - batch_size=args.per_gpu_eval_batch_size, - num_workers=args.num_workers, - shuffle=False, - collate_fn=DataCollator()) - - t_total = len(train_dataloader) * args.num_train_epochs - - # build linear decay with warmup lr sch - lr_scheduler = paddle.optimizer.lr.PolynomialDecay( - learning_rate=args.learning_rate, - decay_steps=t_total, - end_lr=0.0, - power=1.0) - if args.warmup_steps > 0: - lr_scheduler = paddle.optimizer.lr.LinearWarmup( - lr_scheduler, - args.warmup_steps, - start_lr=0, - end_lr=args.learning_rate, ) - grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10) - optimizer = paddle.optimizer.Adam( - learning_rate=args.learning_rate, - parameters=model.parameters(), - epsilon=args.adam_epsilon, - grad_clip=grad_clip, - weight_decay=args.weight_decay) - - # Train! 
- logger.info("***** Running training *****") - logger.info(" Num examples = {}".format(len(train_dataset))) - logger.info(" Num Epochs = {}".format(args.num_train_epochs)) - logger.info(" Instantaneous batch size per GPU = {}".format( - args.per_gpu_train_batch_size)) - logger.info( - " Total train batch size (w. parallel, distributed & accumulation) = {}". - format(args.per_gpu_train_batch_size * - paddle.distributed.get_world_size())) - logger.info(" Total optimization steps = {}".format(t_total)) - - global_step = 0 - model.clear_gradients() - train_dataloader_len = len(train_dataloader) - best_metirc = {'f1': 0} - model.train() - - train_reader_cost = 0.0 - train_run_cost = 0.0 - total_samples = 0 - reader_start = time.time() - - print_step = 1 - - for epoch in range(int(args.num_train_epochs)): - for step, batch in enumerate(train_dataloader): - train_reader_cost += time.time() - reader_start - train_start = time.time() - outputs = model(**batch) - train_run_cost += time.time() - train_start - # model outputs are always tuple in ppnlp (see doc) - loss = outputs['loss'] - loss = loss.mean() - - loss.backward() - optimizer.step() - optimizer.clear_grad() - # lr_scheduler.step() # Update learning rate schedule - - global_step += 1 - total_samples += batch['image'].shape[0] - - if rank == 0 and step % print_step == 0: - logger.info( - "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec". - format(epoch, args.num_train_epochs, step, - train_dataloader_len, global_step, - np.mean(loss.numpy()), - optimizer.get_lr(), train_reader_cost / print_step, ( - train_reader_cost + train_run_cost) / print_step, - total_samples / print_step, total_samples / ( - train_reader_cost + train_run_cost))) - - train_reader_cost = 0.0 - train_run_cost = 0.0 - total_samples = 0 - - if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training: - # Log metrics - # Only evaluate when single GPU otherwise metrics may not average well - results = evaluate(model, eval_dataloader, logger) - if results['f1'] >= best_metirc['f1']: - best_metirc = results - output_dir = os.path.join(args.output_dir, "best_model") - os.makedirs(output_dir, exist_ok=True) - if distributed: - model._layers.save_pretrained(output_dir) - else: - model.save_pretrained(output_dir) - tokenizer.save_pretrained(output_dir) - paddle.save(args, - os.path.join(output_dir, "training_args.bin")) - logger.info("Saving model checkpoint to {}".format( - output_dir)) - logger.info("eval results: {}".format(results)) - logger.info("best_metirc: {}".format(best_metirc)) - reader_start = time.time() - - if rank == 0: - # Save model checkpoint - output_dir = os.path.join(args.output_dir, "latest_model") - os.makedirs(output_dir, exist_ok=True) - if distributed: - model._layers.save_pretrained(output_dir) - else: - model.save_pretrained(output_dir) - tokenizer.save_pretrained(output_dir) - paddle.save(args, os.path.join(output_dir, "training_args.bin")) - logger.info("Saving model checkpoint to {}".format(output_dir)) - logger.info("best_metirc: {}".format(best_metirc)) - - -if __name__ == "__main__": - args = parse_args() - os.makedirs(args.output_dir, exist_ok=True) - train(args) diff --git a/ppstructure/vqa/train_ser.py b/ppstructure/vqa/train_ser.py deleted file mode 100644 index 226172050e..0000000000 --- a/ppstructure/vqa/train_ser.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright 
(c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) - -import random -import time -import copy -import logging - -import argparse -import paddle -import numpy as np -from seqeval.metrics import classification_report, f1_score, precision_score, recall_score -from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification -from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification - -from xfun import XFUNDataset -from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed -from eval_ser import evaluate -from losses import SERLoss -from ppocr.utils.logging import get_logger - -MODELS = { - 'LayoutXLM': - (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification), - 'LayoutLM': - (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification) -} - - -def train(args): - os.makedirs(args.output_dir, exist_ok=True) - rank = paddle.distributed.get_rank() - distributed = paddle.distributed.get_world_size() > 1 - - logger = get_logger(log_file=os.path.join(args.output_dir, "train.log")) - print_arguments(args, logger) - - label2id_map, id2label_map = get_bio_label_maps(args.label_map_path) - loss_class = SERLoss(len(label2id_map)) - - pad_token_label_id = loss_class.ignore_index - - # dist mode - if distributed: - paddle.distributed.init_parallel_env() - - tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type] - tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) - if not args.resume: - base_model = base_model_class.from_pretrained(args.model_name_or_path) - model = model_class( - base_model, num_classes=len(label2id_map), dropout=None) - logger.info('train from scratch') - else: - logger.info('resume from {}'.format(args.model_name_or_path)) - model = model_class.from_pretrained(args.model_name_or_path) - - # dist mode - if distributed: - model = paddle.DataParallel(model) - - train_dataset = XFUNDataset( - tokenizer, - data_dir=args.train_data_dir, - label_path=args.train_label_path, - label2id_map=label2id_map, - img_size=(224, 224), - pad_token_label_id=pad_token_label_id, - contains_re=False, - add_special_ids=False, - return_attention_mask=True, - load_mode='all') - eval_dataset = XFUNDataset( - tokenizer, - data_dir=args.eval_data_dir, - label_path=args.eval_label_path, - label2id_map=label2id_map, - img_size=(224, 224), - pad_token_label_id=pad_token_label_id, - contains_re=False, - add_special_ids=False, - return_attention_mask=True, - load_mode='all') - - train_sampler = paddle.io.DistributedBatchSampler( - train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True) - - train_dataloader = paddle.io.DataLoader( - train_dataset, - batch_sampler=train_sampler, - num_workers=args.num_workers, - use_shared_memory=True, - 
collate_fn=None, ) - - eval_dataloader = paddle.io.DataLoader( - eval_dataset, - batch_size=args.per_gpu_eval_batch_size, - num_workers=args.num_workers, - use_shared_memory=True, - collate_fn=None, ) - - t_total = len(train_dataloader) * args.num_train_epochs - - # build linear decay with warmup lr sch - lr_scheduler = paddle.optimizer.lr.PolynomialDecay( - learning_rate=args.learning_rate, - decay_steps=t_total, - end_lr=0.0, - power=1.0) - if args.warmup_steps > 0: - lr_scheduler = paddle.optimizer.lr.LinearWarmup( - lr_scheduler, - args.warmup_steps, - start_lr=0, - end_lr=args.learning_rate, ) - - optimizer = paddle.optimizer.AdamW( - learning_rate=lr_scheduler, - parameters=model.parameters(), - epsilon=args.adam_epsilon, - weight_decay=args.weight_decay) - - # Train! - logger.info("***** Running training *****") - logger.info(" Num examples = %d", len(train_dataset)) - logger.info(" Num Epochs = %d", args.num_train_epochs) - logger.info(" Instantaneous batch size per GPU = %d", - args.per_gpu_train_batch_size) - logger.info( - " Total train batch size (w. parallel, distributed) = %d", - args.per_gpu_train_batch_size * paddle.distributed.get_world_size(), ) - logger.info(" Total optimization steps = %d", t_total) - - global_step = 0 - tr_loss = 0.0 - set_seed(args.seed) - best_metrics = None - - train_reader_cost = 0.0 - train_run_cost = 0.0 - total_samples = 0 - reader_start = time.time() - - print_step = 1 - model.train() - for epoch_id in range(args.num_train_epochs): - for step, batch in enumerate(train_dataloader): - train_reader_cost += time.time() - reader_start - - if args.ser_model_type == 'LayoutLM': - if 'image' in batch: - batch.pop('image') - labels = batch.pop('labels') - - train_start = time.time() - outputs = model(**batch) - train_run_cost += time.time() - train_start - if args.ser_model_type == 'LayoutXLM': - outputs = outputs[0] - loss = loss_class(labels, outputs, batch['attention_mask']) - - # model outputs are always tuple in ppnlp (see doc) - loss = loss.mean() - loss.backward() - tr_loss += loss.item() - optimizer.step() - lr_scheduler.step() # Update learning rate schedule - optimizer.clear_grad() - global_step += 1 - total_samples += batch['input_ids'].shape[0] - - if rank == 0 and step % print_step == 0: - logger.info( - "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec". 
- format(epoch_id, args.num_train_epochs, step, - len(train_dataloader), global_step, - loss.numpy()[0], - lr_scheduler.get_lr(), train_reader_cost / - print_step, (train_reader_cost + train_run_cost) / - print_step, total_samples / print_step, total_samples - / (train_reader_cost + train_run_cost))) - - train_reader_cost = 0.0 - train_run_cost = 0.0 - total_samples = 0 - - if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training: - # Log metrics - # Only evaluate when single GPU otherwise metrics may not average well - results, _ = evaluate(args, model, tokenizer, loss_class, - eval_dataloader, label2id_map, - id2label_map, pad_token_label_id, logger) - - if best_metrics is None or results["f1"] >= best_metrics["f1"]: - best_metrics = copy.deepcopy(results) - output_dir = os.path.join(args.output_dir, "best_model") - os.makedirs(output_dir, exist_ok=True) - if distributed: - model._layers.save_pretrained(output_dir) - else: - model.save_pretrained(output_dir) - tokenizer.save_pretrained(output_dir) - paddle.save(args, - os.path.join(output_dir, "training_args.bin")) - logger.info("Saving model checkpoint to {}".format( - output_dir)) - - logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format( - epoch_id, args.num_train_epochs, step, - len(train_dataloader), results)) - if best_metrics is not None: - logger.info("best metrics: {}".format(best_metrics)) - reader_start = time.time() - if rank == 0: - # Save model checkpoint - output_dir = os.path.join(args.output_dir, "latest_model") - os.makedirs(output_dir, exist_ok=True) - if distributed: - model._layers.save_pretrained(output_dir) - else: - model.save_pretrained(output_dir) - tokenizer.save_pretrained(output_dir) - paddle.save(args, os.path.join(output_dir, "training_args.bin")) - logger.info("Saving model checkpoint to {}".format(output_dir)) - return global_step, tr_loss / global_step - - -if __name__ == "__main__": - args = parse_args() - train(args) diff --git a/ppstructure/vqa/vqa_utils.py b/ppstructure/vqa/vqa_utils.py deleted file mode 100644 index b9f2edc860..0000000000 --- a/ppstructure/vqa/vqa_utils.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import argparse -import cv2 -import random -import numpy as np -import imghdr -from copy import deepcopy - -import paddle - -from PIL import Image, ImageDraw, ImageFont - - -def set_seed(seed): - random.seed(seed) - np.random.seed(seed) - paddle.seed(seed) - - -def get_bio_label_maps(label_map_path): - with open(label_map_path, "r", encoding='utf-8') as fin: - lines = fin.readlines() - lines = [line.strip() for line in lines] - if "O" not in lines: - lines.insert(0, "O") - labels = [] - for line in lines: - if line == "O": - labels.append("O") - else: - labels.append("B-" + line) - labels.append("I-" + line) - label2id_map = {label: idx for idx, label in enumerate(labels)} - id2label_map = {idx: label for idx, label in enumerate(labels)} - return label2id_map, id2label_map - - -def get_image_file_list(img_file): - imgs_lists = [] - if img_file is None or not os.path.exists(img_file): - raise Exception("not found any img file in {}".format(img_file)) - - img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'} - if os.path.isfile(img_file) and imghdr.what(img_file) in img_end: - imgs_lists.append(img_file) - elif os.path.isdir(img_file): - for single_file in os.listdir(img_file): - file_path = os.path.join(img_file, single_file) - if os.path.isfile(file_path) and imghdr.what(file_path) in img_end: - imgs_lists.append(file_path) - if len(imgs_lists) == 0: - raise Exception("not found any img file in {}".format(img_file)) - imgs_lists = sorted(imgs_lists) - return imgs_lists - - -def draw_ser_results(image, - ocr_results, - font_path="../../doc/fonts/simfang.ttf", - font_size=18): - np.random.seed(2021) - color = (np.random.permutation(range(255)), - np.random.permutation(range(255)), - np.random.permutation(range(255))) - color_map = { - idx: (color[0][idx], color[1][idx], color[2][idx]) - for idx in range(1, 255) - } - if isinstance(image, np.ndarray): - image = Image.fromarray(image) - img_new = image.copy() - draw = ImageDraw.Draw(img_new) - - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") - for ocr_info in ocr_results: - if ocr_info["pred_id"] not in color_map: - continue - color = color_map[ocr_info["pred_id"]] - text = "{}: {}".format(ocr_info["pred"], ocr_info["text"]) - - draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color) - - img_new = Image.blend(image, img_new, 0.5) - return np.array(img_new) - - -def draw_box_txt(bbox, text, draw, font, font_size, color): - # draw ocr results outline - bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3])) - draw.rectangle(bbox, fill=color) - - # draw ocr results - start_y = max(0, bbox[0][1] - font_size) - tw = font.getsize(text)[0] - draw.rectangle( - [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)], - fill=(0, 0, 255)) - draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font) - - -def draw_re_results(image, - result, - font_path="../../doc/fonts/simfang.ttf", - font_size=18): - np.random.seed(0) - if isinstance(image, np.ndarray): - image = Image.fromarray(image) - img_new = image.copy() - draw = ImageDraw.Draw(img_new) - - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") - color_head = (0, 0, 255) - color_tail = (255, 0, 0) - color_line = (0, 255, 0) - - for ocr_info_head, ocr_info_tail in result: - draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font, - font_size, color_head) - draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font, - font_size, color_tail) - - center_head = ( - 
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2, - (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2) - center_tail = ( - (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2, - (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2) - - draw.line([center_head, center_tail], fill=color_line, width=5) - - img_new = Image.blend(image, img_new, 0.5) - return np.array(img_new) - - -# pad sentences -def pad_sentences(tokenizer, - encoded_inputs, - max_seq_len=512, - pad_to_max_seq_len=True, - return_attention_mask=True, - return_token_type_ids=True, - return_overflowing_tokens=False, - return_special_tokens_mask=False): - # Padding with larger size, reshape is carried out - max_seq_len = ( - len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len - - needs_to_be_padded = pad_to_max_seq_len and \ - max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len - - if needs_to_be_padded: - difference = max_seq_len - len(encoded_inputs["input_ids"]) - if tokenizer.padding_side == 'right': - if return_attention_mask: - encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[ - "input_ids"]) + [0] * difference - if return_token_type_ids: - encoded_inputs["token_type_ids"] = ( - encoded_inputs["token_type_ids"] + - [tokenizer.pad_token_type_id] * difference) - if return_special_tokens_mask: - encoded_inputs["special_tokens_mask"] = encoded_inputs[ - "special_tokens_mask"] + [1] * difference - encoded_inputs["input_ids"] = encoded_inputs[ - "input_ids"] + [tokenizer.pad_token_id] * difference - encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0] - ] * difference - else: - if return_attention_mask: - encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[ - "input_ids"]) - - return encoded_inputs - - -def split_page(encoded_inputs, max_seq_len=512): - """ - truncate is often used in training process - """ - for key in encoded_inputs: - if key == 'entities': - encoded_inputs[key] = [encoded_inputs[key]] - continue - encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key]) - if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on - encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len]) - else: # for bbox - encoded_inputs[key] = encoded_inputs[key].reshape( - [-1, max_seq_len, 4]) - return encoded_inputs - - -def preprocess( - tokenizer, - ori_img, - ocr_info, - img_size=(224, 224), - pad_token_label_id=-100, - max_seq_len=512, - add_special_ids=False, - return_attention_mask=True, ): - ocr_info = deepcopy(ocr_info) - height = ori_img.shape[0] - width = ori_img.shape[1] - - img = cv2.resize(ori_img, img_size).transpose([2, 0, 1]).astype(np.float32) - - segment_offset_id = [] - words_list = [] - bbox_list = [] - input_ids_list = [] - token_type_ids_list = [] - entities = [] - - for info in ocr_info: - # x1, y1, x2, y2 - bbox = info["bbox"] - bbox[0] = int(bbox[0] * 1000.0 / width) - bbox[2] = int(bbox[2] * 1000.0 / width) - bbox[1] = int(bbox[1] * 1000.0 / height) - bbox[3] = int(bbox[3] * 1000.0 / height) - - text = info["text"] - encode_res = tokenizer.encode( - text, pad_to_max_seq_len=False, return_attention_mask=True) - - if not add_special_ids: - # TODO: use tok.all_special_ids to remove - encode_res["input_ids"] = encode_res["input_ids"][1:-1] - encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1] - encode_res["attention_mask"] = encode_res["attention_mask"][1:-1] - - # for re - entities.append({ - "start": len(input_ids_list), - "end": len(input_ids_list) + len(encode_res["input_ids"]), - 
"label": "O", - }) - - input_ids_list.extend(encode_res["input_ids"]) - token_type_ids_list.extend(encode_res["token_type_ids"]) - bbox_list.extend([bbox] * len(encode_res["input_ids"])) - words_list.append(text) - segment_offset_id.append(len(input_ids_list)) - - encoded_inputs = { - "input_ids": input_ids_list, - "token_type_ids": token_type_ids_list, - "bbox": bbox_list, - "attention_mask": [1] * len(input_ids_list), - "entities": entities - } - - encoded_inputs = pad_sentences( - tokenizer, - encoded_inputs, - max_seq_len=max_seq_len, - return_attention_mask=return_attention_mask) - - encoded_inputs = split_page(encoded_inputs) - - fake_bs = encoded_inputs["input_ids"].shape[0] - - encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand( - [fake_bs] + list(img.shape)) - - encoded_inputs["segment_offset_id"] = segment_offset_id - - return encoded_inputs - - -def postprocess(attention_mask, preds, id2label_map): - if isinstance(preds, paddle.Tensor): - preds = preds.numpy() - preds = np.argmax(preds, axis=2) - - preds_list = [[] for _ in range(preds.shape[0])] - - # keep batch info - for i in range(preds.shape[0]): - for j in range(preds.shape[1]): - if attention_mask[i][j] == 1: - preds_list[i].append(id2label_map[preds[i][j]]) - - return preds_list - - -def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list, - label2id_map_for_draw): - # must ensure the preds_list is generated from the same image - preds = [p for pred in preds_list for p in pred] - - id2label_map = dict() - for key in label2id_map_for_draw: - val = label2id_map_for_draw[key] - if key == "O": - id2label_map[val] = key - if key.startswith("B-") or key.startswith("I-"): - id2label_map[val] = key[2:] - else: - id2label_map[val] = key - - for idx in range(len(segment_offset_id)): - if idx == 0: - start_id = 0 - else: - start_id = segment_offset_id[idx - 1] - - end_id = segment_offset_id[idx] - - curr_pred = preds[start_id:end_id] - curr_pred = [label2id_map_for_draw[p] for p in curr_pred] - - if len(curr_pred) <= 0: - pred_id = 0 - else: - counts = np.bincount(curr_pred) - pred_id = np.argmax(counts) - ocr_info[idx]["pred_id"] = int(pred_id) - ocr_info[idx]["pred"] = id2label_map[int(pred_id)] - return ocr_info - - -def print_arguments(args, logger=None): - print_func = logger.info if logger is not None else print - """print arguments""" - print_func('----------- Configuration Arguments -----------') - for arg, value in sorted(vars(args).items()): - print_func('%s: %s' % (arg, value)) - print_func('------------------------------------------------') - - -def parse_args(): - parser = argparse.ArgumentParser() - # Required parameters - # yapf: disable - parser.add_argument("--model_name_or_path", - default=None, type=str, required=True,) - parser.add_argument("--ser_model_type", - default='LayoutXLM', type=str) - parser.add_argument("--re_model_name_or_path", - default=None, type=str, required=False,) - parser.add_argument("--train_data_dir", default=None, - type=str, required=False,) - parser.add_argument("--train_label_path", default=None, - type=str, required=False,) - parser.add_argument("--eval_data_dir", default=None, - type=str, required=False,) - parser.add_argument("--eval_label_path", default=None, - type=str, required=False,) - parser.add_argument("--output_dir", default=None, type=str, required=True,) - parser.add_argument("--max_seq_length", default=512, type=int,) - parser.add_argument("--evaluate_during_training", action="store_true",) - parser.add_argument("--num_workers", 
default=8, type=int,) - parser.add_argument("--per_gpu_train_batch_size", default=8, - type=int, help="Batch size per GPU/CPU for training.",) - parser.add_argument("--per_gpu_eval_batch_size", default=8, - type=int, help="Batch size per GPU/CPU for eval.",) - parser.add_argument("--learning_rate", default=5e-5, - type=float, help="The initial learning rate for Adam.",) - parser.add_argument("--weight_decay", default=0.0, - type=float, help="Weight decay if we apply some.",) - parser.add_argument("--adam_epsilon", default=1e-8, - type=float, help="Epsilon for Adam optimizer.",) - parser.add_argument("--max_grad_norm", default=1.0, - type=float, help="Max gradient norm.",) - parser.add_argument("--num_train_epochs", default=3, type=int, - help="Total number of training epochs to perform.",) - parser.add_argument("--warmup_steps", default=0, type=int, - help="Linear warmup over warmup_steps.",) - parser.add_argument("--eval_steps", type=int, default=10, - help="eval every X updates steps.",) - parser.add_argument("--seed", type=int, default=2048, - help="random seed for initialization",) - - parser.add_argument("--rec_model_dir", default=None, type=str, ) - parser.add_argument("--det_model_dir", default=None, type=str, ) - parser.add_argument( - "--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, ) - parser.add_argument("--infer_imgs", default=None, type=str, required=False) - parser.add_argument("--resume", action='store_true') - parser.add_argument("--ocr_json_path", default=None, - type=str, required=False, help="ocr prediction results") - # yapf: enable - args = parser.parse_args() - return args diff --git a/ppstructure/vqa/xfun.py b/ppstructure/vqa/xfun.py deleted file mode 100644 index f5dbe507e8..0000000000 --- a/ppstructure/vqa/xfun.py +++ /dev/null @@ -1,464 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import os -import cv2 -import numpy as np -import paddle -import copy -from paddle.io import Dataset - -__all__ = ["XFUNDataset"] - - -class XFUNDataset(Dataset): - """ - Example: - print("=====begin to build dataset=====") - from paddlenlp.transformers import LayoutXLMTokenizer - tokenizer = LayoutXLMTokenizer.from_pretrained("/paddle/models/transformers/layoutxlm-base-paddle/") - tok_res = tokenizer.tokenize("Maribyrnong") - # res = tokenizer.convert_ids_to_tokens(val_data["input_ids"][0]) - dataset = XfunDatasetForSer( - tokenizer, - data_dir="./zh.val/", - label_path="zh.val/xfun_normalize_val.json", - img_size=(224,224)) - print(len(dataset)) - - data = dataset[0] - print(data.keys()) - print("input_ids: ", data["input_ids"]) - print("labels: ", data["labels"]) - print("token_type_ids: ", data["token_type_ids"]) - print("words_list: ", data["words_list"]) - print("image shape: ", data["image"].shape) - """ - - def __init__(self, - tokenizer, - data_dir, - label_path, - contains_re=False, - label2id_map=None, - img_size=(224, 224), - pad_token_label_id=None, - add_special_ids=False, - return_attention_mask=True, - load_mode='all', - max_seq_len=512): - super().__init__() - self.tokenizer = tokenizer - self.data_dir = data_dir - self.label_path = label_path - self.contains_re = contains_re - self.label2id_map = label2id_map - self.img_size = img_size - self.pad_token_label_id = pad_token_label_id - self.add_special_ids = add_special_ids - self.return_attention_mask = return_attention_mask - self.load_mode = load_mode - self.max_seq_len = max_seq_len - - if self.pad_token_label_id is None: - self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index - - self.all_lines = self.read_all_lines() - - self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2} - self.return_keys = { - 'bbox': { - 'type': 'np', - 'dtype': 'int64' - }, - 'input_ids': { - 'type': 'np', - 'dtype': 'int64' - }, - 'labels': { - 'type': 'np', - 'dtype': 'int64' - }, - 'attention_mask': { - 'type': 'np', - 'dtype': 'int64' - }, - 'image': { - 'type': 'np', - 'dtype': 'float32' - }, - 'token_type_ids': { - 'type': 'np', - 'dtype': 'int64' - }, - 'entities': { - 'type': 'dict' - }, - 'relations': { - 'type': 'dict' - } - } - - if load_mode == "all": - self.encoded_inputs_all = self._parse_label_file_all() - - def pad_sentences(self, - encoded_inputs, - max_seq_len=512, - pad_to_max_seq_len=True, - return_attention_mask=True, - return_token_type_ids=True, - truncation_strategy="longest_first", - return_overflowing_tokens=False, - return_special_tokens_mask=False): - # Padding - needs_to_be_padded = pad_to_max_seq_len and \ - max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len - - if needs_to_be_padded: - difference = max_seq_len - len(encoded_inputs["input_ids"]) - if self.tokenizer.padding_side == 'right': - if return_attention_mask: - encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[ - "input_ids"]) + [0] * difference - if return_token_type_ids: - encoded_inputs["token_type_ids"] = ( - encoded_inputs["token_type_ids"] + - [self.tokenizer.pad_token_type_id] * difference) - if return_special_tokens_mask: - encoded_inputs["special_tokens_mask"] = encoded_inputs[ - "special_tokens_mask"] + [1] * difference - encoded_inputs["input_ids"] = encoded_inputs[ - "input_ids"] + [self.tokenizer.pad_token_id] * difference - encoded_inputs["labels"] = encoded_inputs[ - "labels"] + [self.pad_token_label_id] * difference - encoded_inputs["bbox"] = encoded_inputs[ - "bbox"] + [[0, 0, 
0, 0]] * difference - elif self.tokenizer.padding_side == 'left': - if return_attention_mask: - encoded_inputs["attention_mask"] = [0] * difference + [ - 1 - ] * len(encoded_inputs["input_ids"]) - if return_token_type_ids: - encoded_inputs["token_type_ids"] = ( - [self.tokenizer.pad_token_type_id] * difference + - encoded_inputs["token_type_ids"]) - if return_special_tokens_mask: - encoded_inputs["special_tokens_mask"] = [ - 1 - ] * difference + encoded_inputs["special_tokens_mask"] - encoded_inputs["input_ids"] = [ - self.tokenizer.pad_token_id - ] * difference + encoded_inputs["input_ids"] - encoded_inputs["labels"] = [ - self.pad_token_label_id - ] * difference + encoded_inputs["labels"] - encoded_inputs["bbox"] = [ - [0, 0, 0, 0] - ] * difference + encoded_inputs["bbox"] - else: - if return_attention_mask: - encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[ - "input_ids"]) - - return encoded_inputs - - def truncate_inputs(self, encoded_inputs, max_seq_len=512): - for key in encoded_inputs: - if key == "sample_id": - continue - length = min(len(encoded_inputs[key]), max_seq_len) - encoded_inputs[key] = encoded_inputs[key][:length] - return encoded_inputs - - def read_all_lines(self, ): - with open(self.label_path, "r", encoding='utf-8') as fin: - lines = fin.readlines() - return lines - - def _parse_label_file_all(self): - """ - parse all samples - """ - encoded_inputs_all = [] - for line in self.all_lines: - encoded_inputs_all.extend(self._parse_label_file(line)) - return encoded_inputs_all - - def _parse_label_file(self, line): - """ - parse single sample - """ - - image_name, info_str = line.split("\t") - image_path = os.path.join(self.data_dir, image_name) - - def add_imgge_path(x): - x['image_path'] = image_path - return x - - encoded_inputs = self._read_encoded_inputs_sample(info_str) - if self.contains_re: - encoded_inputs = self._chunk_re(encoded_inputs) - else: - encoded_inputs = self._chunk_ser(encoded_inputs) - encoded_inputs = list(map(add_imgge_path, encoded_inputs)) - return encoded_inputs - - def _read_encoded_inputs_sample(self, info_str): - """ - parse label info - """ - # read text info - info_dict = json.loads(info_str) - height = info_dict["height"] - width = info_dict["width"] - - words_list = [] - bbox_list = [] - input_ids_list = [] - token_type_ids_list = [] - gt_label_list = [] - - if self.contains_re: - # for re - entities = [] - relations = [] - id2label = {} - entity_id_to_index_map = {} - empty_entity = set() - for info in info_dict["ocr_info"]: - if self.contains_re: - # for re - if len(info["text"]) == 0: - empty_entity.add(info["id"]) - continue - id2label[info["id"]] = info["label"] - relations.extend([tuple(sorted(l)) for l in info["linking"]]) - - # x1, y1, x2, y2 - bbox = info["bbox"] - label = info["label"] - bbox[0] = int(bbox[0] * 1000.0 / width) - bbox[2] = int(bbox[2] * 1000.0 / width) - bbox[1] = int(bbox[1] * 1000.0 / height) - bbox[3] = int(bbox[3] * 1000.0 / height) - - text = info["text"] - encode_res = self.tokenizer.encode( - text, pad_to_max_seq_len=False, return_attention_mask=True) - - gt_label = [] - if not self.add_special_ids: - # TODO: use tok.all_special_ids to remove - encode_res["input_ids"] = encode_res["input_ids"][1:-1] - encode_res["token_type_ids"] = encode_res["token_type_ids"][1: - -1] - encode_res["attention_mask"] = encode_res["attention_mask"][1: - -1] - if label.lower() == "other": - gt_label.extend([0] * len(encode_res["input_ids"])) - else: - gt_label.append(self.label2id_map[("b-" + label).upper()]) - 
gt_label.extend([self.label2id_map[("i-" + label).upper()]] * - (len(encode_res["input_ids"]) - 1)) - if self.contains_re: - if gt_label[0] != self.label2id_map["O"]: - entity_id_to_index_map[info["id"]] = len(entities) - entities.append({ - "start": len(input_ids_list), - "end": - len(input_ids_list) + len(encode_res["input_ids"]), - "label": label.upper(), - }) - input_ids_list.extend(encode_res["input_ids"]) - token_type_ids_list.extend(encode_res["token_type_ids"]) - bbox_list.extend([bbox] * len(encode_res["input_ids"])) - gt_label_list.extend(gt_label) - words_list.append(text) - - encoded_inputs = { - "input_ids": input_ids_list, - "labels": gt_label_list, - "token_type_ids": token_type_ids_list, - "bbox": bbox_list, - "attention_mask": [1] * len(input_ids_list), - # "words_list": words_list, - } - encoded_inputs = self.pad_sentences( - encoded_inputs, - max_seq_len=self.max_seq_len, - return_attention_mask=self.return_attention_mask) - encoded_inputs = self.truncate_inputs(encoded_inputs) - - if self.contains_re: - relations = self._relations(entities, relations, id2label, - empty_entity, entity_id_to_index_map) - encoded_inputs['relations'] = relations - encoded_inputs['entities'] = entities - return encoded_inputs - - def _chunk_ser(self, encoded_inputs): - encoded_inputs_all = [] - seq_len = len(encoded_inputs['input_ids']) - chunk_size = 512 - for chunk_id, index in enumerate(range(0, seq_len, chunk_size)): - chunk_beg = index - chunk_end = min(index + chunk_size, seq_len) - encoded_inputs_example = {} - for key in encoded_inputs: - encoded_inputs_example[key] = encoded_inputs[key][chunk_beg: - chunk_end] - - encoded_inputs_all.append(encoded_inputs_example) - return encoded_inputs_all - - def _chunk_re(self, encoded_inputs): - # prepare data - entities = encoded_inputs.pop('entities') - relations = encoded_inputs.pop('relations') - encoded_inputs_all = [] - chunk_size = 512 - for chunk_id, index in enumerate( - range(0, len(encoded_inputs["input_ids"]), chunk_size)): - item = {} - for k in encoded_inputs: - item[k] = encoded_inputs[k][index:index + chunk_size] - - # select entity in current chunk - entities_in_this_span = [] - global_to_local_map = {} # - for entity_id, entity in enumerate(entities): - if (index <= entity["start"] < index + chunk_size and - index <= entity["end"] < index + chunk_size): - entity["start"] = entity["start"] - index - entity["end"] = entity["end"] - index - global_to_local_map[entity_id] = len(entities_in_this_span) - entities_in_this_span.append(entity) - - # select relations in current chunk - relations_in_this_span = [] - for relation in relations: - if (index <= relation["start_index"] < index + chunk_size and - index <= relation["end_index"] < index + chunk_size): - relations_in_this_span.append({ - "head": global_to_local_map[relation["head"]], - "tail": global_to_local_map[relation["tail"]], - "start_index": relation["start_index"] - index, - "end_index": relation["end_index"] - index, - }) - item.update({ - "entities": reformat(entities_in_this_span), - "relations": reformat(relations_in_this_span), - }) - item['entities']['label'] = [ - self.entities_labels[x] for x in item['entities']['label'] - ] - encoded_inputs_all.append(item) - return encoded_inputs_all - - def _relations(self, entities, relations, id2label, empty_entity, - entity_id_to_index_map): - """ - build relations - """ - relations = list(set(relations)) - relations = [ - rel for rel in relations - if rel[0] not in empty_entity and rel[1] not in empty_entity - ] - 
kv_relations = [] - for rel in relations: - pair = [id2label[rel[0]], id2label[rel[1]]] - if pair == ["question", "answer"]: - kv_relations.append({ - "head": entity_id_to_index_map[rel[0]], - "tail": entity_id_to_index_map[rel[1]] - }) - elif pair == ["answer", "question"]: - kv_relations.append({ - "head": entity_id_to_index_map[rel[1]], - "tail": entity_id_to_index_map[rel[0]] - }) - else: - continue - relations = sorted( - [{ - "head": rel["head"], - "tail": rel["tail"], - "start_index": get_relation_span(rel, entities)[0], - "end_index": get_relation_span(rel, entities)[1], - } for rel in kv_relations], - key=lambda x: x["head"], ) - return relations - - def load_img(self, image_path): - # read img - img = cv2.imread(image_path) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - resize_h, resize_w = self.img_size - im_shape = img.shape[0:2] - im_scale_y = resize_h / im_shape[0] - im_scale_x = resize_w / im_shape[1] - img_new = cv2.resize( - img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2) - mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :] - std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :] - img_new = img_new / 255.0 - img_new -= mean - img_new /= std - img = img_new.transpose((2, 0, 1)) - return img - - def __getitem__(self, idx): - if self.load_mode == "all": - data = copy.deepcopy(self.encoded_inputs_all[idx]) - else: - data = self._parse_label_file(self.all_lines[idx])[0] - - image_path = data.pop('image_path') - data["image"] = self.load_img(image_path) - - return_data = {} - for k, v in data.items(): - if k in self.return_keys: - if self.return_keys[k]['type'] == 'np': - v = np.array(v, dtype=self.return_keys[k]['dtype']) - return_data[k] = v - return return_data - - def __len__(self, ): - if self.load_mode == "all": - return len(self.encoded_inputs_all) - else: - return len(self.all_lines) - - -def get_relation_span(rel, entities): - bound = [] - for entity_index in [rel["head"], rel["tail"]]: - bound.append(entities[entity_index]["start"]) - bound.append(entities[entity_index]["end"]) - return min(bound), max(bound) - - -def reformat(data): - new_data = {} - for item in data: - for k, v in item.items(): - if k not in new_data: - new_data[k] = [] - new_data[k].append(v) - return new_data diff --git a/requirements.txt b/requirements.txt index 3865781996..1d9522aa01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,4 @@ cython lxml premailer openpyxl -paddlenlp>=2.2.1 +fasttext==0.9.1 diff --git a/tools/eval.py b/tools/eval.py index 13a4a0882f..3a25c2660d 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -61,7 +61,8 @@ def main(): else: model_type = None - best_model_dict = load_model(config, model) + best_model_dict = load_model( + config, model, model_type=config['Architecture']["model_type"]) if len(best_model_dict): logger.info('metric in ckpt ***************') for k, v in best_model_dict.items(): diff --git a/tools/export_model.py b/tools/export_model.py index 9ed8e1b6ac..695af5c8bd 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -85,7 +85,7 @@ def export_single_model(model, arch_config, save_path, logger): def main(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) - merge_config(FLAGS.opt) + config = merge_config(config, FLAGS.opt) logger = get_logger() # build post process diff --git a/tools/infer_vqa_token_ser.py b/tools/infer_vqa_token_ser.py new file mode 100755 index 0000000000..5859c28f92 --- /dev/null +++ b/tools/infer_vqa_token_ser.py @@ -0,0 +1,135 @@ +# Copyright (c) 
2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' +import cv2 +import json +import paddle + +from ppocr.data import create_operators, transform +from ppocr.modeling.architectures import build_model +from ppocr.postprocess import build_post_process +from ppocr.utils.save_load import load_model +from ppocr.utils.visual import draw_ser_results +from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps +import tools.program as program + + +def to_tensor(data): + import numbers + from collections import defaultdict + data_dict = defaultdict(list) + to_tensor_idxs = [] + for idx, v in enumerate(data): + if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): + if idx not in to_tensor_idxs: + to_tensor_idxs.append(idx) + data_dict[idx].append(v) + for idx in to_tensor_idxs: + data_dict[idx] = paddle.to_tensor(data_dict[idx]) + return list(data_dict.values()) + + +class SerPredictor(object): + def __init__(self, config): + global_config = config['Global'] + + # build post process + self.post_process_class = build_post_process(config['PostProcess'], + global_config) + + # build model + self.model = build_model(config['Architecture']) + + load_model( + config, self.model, model_type=config['Architecture']["model_type"]) + + from paddleocr import PaddleOCR + + self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False) + + # create data ops + transforms = [] + for op in config['Eval']['dataset']['transforms']: + op_name = list(op)[0] + if 'Label' in op_name: + op[op_name]['ocr_engine'] = self.ocr_engine + elif op_name == 'KeepKeys': + op[op_name]['keep_keys'] = [ + 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', + 'token_type_ids', 'segment_offset_id', 'ocr_info', + 'entities' + ] + + transforms.append(op) + global_config['infer_mode'] = True + self.ops = create_operators(config['Eval']['dataset']['transforms'], + global_config) + self.model.eval() + + def __call__(self, img_path): + with open(img_path, 'rb') as f: + img = f.read() + data = {'image': img} + batch = transform(data, self.ops) + batch = to_tensor(batch) + preds = self.model(batch) + post_result = self.post_process_class( + preds, + attention_masks=batch[4], + segment_offset_ids=batch[6], + ocr_infos=batch[7]) + return post_result, batch + + +if __name__ == '__main__': + config, device, logger, vdl_writer = program.preprocess() + os.makedirs(config['Global']['save_res_path'], exist_ok=True) + + ser_engine = SerPredictor(config) + + infer_imgs = get_image_file_list(config['Global']['infer_img']) + with open( + os.path.join(config['Global']['save_res_path'], + "infer_results.txt"), + "w", + 
encoding='utf-8') as fout: + for idx, img_path in enumerate(infer_imgs): + save_img_path = os.path.join( + config['Global']['save_res_path'], + os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg") + logger.info("process: [{}/{}], save result to {}".format( + idx, len(infer_imgs), save_img_path)) + + result, _ = ser_engine(img_path) + result = result[0] + fout.write(img_path + "\t" + json.dumps( + { + "ocr_info": result, + }, ensure_ascii=False) + "\n") + img_res = draw_ser_results(img_path, result) + cv2.imwrite(save_img_path, img_res) diff --git a/tools/infer_vqa_token_ser_re.py b/tools/infer_vqa_token_ser_re.py new file mode 100755 index 0000000000..fd62ace8ae --- /dev/null +++ b/tools/infer_vqa_token_ser_re.py @@ -0,0 +1,199 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' +import cv2 +import json +import paddle +import paddle.distributed as dist + +from ppocr.data import create_operators, transform +from ppocr.modeling.architectures import build_model +from ppocr.postprocess import build_post_process +from ppocr.utils.save_load import load_model +from ppocr.utils.visual import draw_re_results +from ppocr.utils.logging import get_logger +from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict +from tools.program import ArgsParser, load_config, merge_config, check_gpu +from tools.infer_vqa_token_ser import SerPredictor + + +class ReArgsParser(ArgsParser): + def __init__(self): + super(ReArgsParser, self).__init__() + self.add_argument( + "-c_ser", "--config_ser", help="ser configuration file to use") + self.add_argument( + "-o_ser", + "--opt_ser", + nargs='+', + help="set ser configuration options ") + + def parse_args(self, argv=None): + args = super(ReArgsParser, self).parse_args(argv) + assert args.config_ser is not None, \ + "Please specify --config_ser=ser_configure_file_path." 
+ args.opt_ser = self._parse_opt(args.opt_ser) + return args + + +def make_input(ser_inputs, ser_results): + entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2} + + entities = ser_inputs[8][0] + ser_results = ser_results[0] + assert len(entities) == len(ser_results) + + # entities + start = [] + end = [] + label = [] + entity_idx_dict = {} + for i, (res, entity) in enumerate(zip(ser_results, entities)): + if res['pred'] == 'O': + continue + entity_idx_dict[len(start)] = i + start.append(entity['start']) + end.append(entity['end']) + label.append(entities_labels[res['pred']]) + entities = dict(start=start, end=end, label=label) + + # relations + head = [] + tail = [] + for i in range(len(entities["label"])): + for j in range(len(entities["label"])): + if entities["label"][i] == 1 and entities["label"][j] == 2: + head.append(i) + tail.append(j) + + relations = dict(head=head, tail=tail) + + batch_size = ser_inputs[0].shape[0] + entities_batch = [] + relations_batch = [] + entity_idx_dict_batch = [] + for b in range(batch_size): + entities_batch.append(entities) + relations_batch.append(relations) + entity_idx_dict_batch.append(entity_idx_dict) + + ser_inputs[8] = entities_batch + ser_inputs.append(relations_batch) + # remove ocr_info segment_offset_id and label in ser input + ser_inputs.pop(7) + ser_inputs.pop(6) + ser_inputs.pop(1) + return ser_inputs, entity_idx_dict_batch + + +class SerRePredictor(object): + def __init__(self, config, ser_config): + self.ser_engine = SerPredictor(ser_config) + + # init re model + global_config = config['Global'] + + # build post process + self.post_process_class = build_post_process(config['PostProcess'], + global_config) + + # build model + self.model = build_model(config['Architecture']) + + load_model( + config, self.model, model_type=config['Architecture']["model_type"]) + + self.model.eval() + + def __call__(self, img_path): + ser_results, ser_inputs = self.ser_engine(img_path) + paddle.save(ser_inputs, 'ser_inputs.npy') + paddle.save(ser_results, 'ser_results.npy') + re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results) + preds = self.model(re_input) + post_result = self.post_process_class( + preds, + ser_results=ser_results, + entity_idx_dict_batch=entity_idx_dict_batch) + return post_result + + +def preprocess(): + FLAGS = ReArgsParser().parse_args() + config = load_config(FLAGS.config) + config = merge_config(config, FLAGS.opt) + + ser_config = load_config(FLAGS.config_ser) + ser_config = merge_config(ser_config, FLAGS.opt_ser) + + logger = get_logger(name='root') + + # check if set use_gpu=True in paddlepaddle cpu version + use_gpu = config['Global']['use_gpu'] + check_gpu(use_gpu) + + device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' + device = paddle.set_device(device) + + logger.info('{} re config {}'.format('*' * 10, '*' * 10)) + print_dict(config, logger) + logger.info('\n') + logger.info('{} ser config {}'.format('*' * 10, '*' * 10)) + print_dict(ser_config, logger) + logger.info('train with paddle {} and device {}'.format(paddle.__version__, + device)) + return config, ser_config, device, logger + + +if __name__ == '__main__': + config, ser_config, device, logger = preprocess() + os.makedirs(config['Global']['save_res_path'], exist_ok=True) + + ser_re_engine = SerRePredictor(config, ser_config) + + infer_imgs = get_image_file_list(config['Global']['infer_img']) + with open( + os.path.join(config['Global']['save_res_path'], + "infer_results.txt"), + "w", + encoding='utf-8') as fout: + for 
idx, img_path in enumerate(infer_imgs): + save_img_path = os.path.join( + config['Global']['save_res_path'], + os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg") + logger.info("process: [{}/{}], save result to {}".format( + idx, len(infer_imgs), save_img_path)) + + result = ser_re_engine(img_path) + result = result[0] + fout.write(img_path + "\t" + json.dumps( + { + "ser_result": result, + }, ensure_ascii=False) + "\n") + img_res = draw_re_results(img_path, result) + cv2.imwrite(save_img_path, img_res) diff --git a/tools/program.py b/tools/program.py index 333e8ed977..743ace090c 100755 --- a/tools/program.py +++ b/tools/program.py @@ -69,24 +69,6 @@ class ArgsParser(ArgumentParser): return config -class AttrDict(dict): - """Single level attribute dict, NOT recursive""" - - def __init__(self, **kwargs): - super(AttrDict, self).__init__() - super(AttrDict, self).update(kwargs) - - def __getattr__(self, key): - if key in self: - return self[key] - raise AttributeError("object has no attribute '{}'".format(key)) - - -global_config = AttrDict() - -default_config = {'Global': {'debug': False, }} - - def load_config(file_path): """ Load config from yml/yaml file. @@ -94,38 +76,38 @@ def load_config(file_path): file_path (str): Path of the config file to be loaded. Returns: global config """ - merge_config(default_config) _, ext = os.path.splitext(file_path) assert ext in ['.yml', '.yaml'], "only support yaml files for now" - merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader)) - return global_config + config = yaml.load(open(file_path, 'rb'), Loader=yaml.Loader) + return config -def merge_config(config): +def merge_config(config, opts): """ Merge config into global config. Args: config (dict): Config to be merged. Returns: global config """ - for key, value in config.items(): + for key, value in opts.items(): if "." 
not in key: - if isinstance(value, dict) and key in global_config: - global_config[key].update(value) + if isinstance(value, dict) and key in config: + config[key].update(value) else: - global_config[key] = value + config[key] = value else: sub_keys = key.split('.') assert ( - sub_keys[0] in global_config + sub_keys[0] in config ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format( - global_config.keys(), sub_keys[0]) - cur = global_config[sub_keys[0]] + config.keys(), sub_keys[0]) + cur = config[sub_keys[0]] for idx, sub_key in enumerate(sub_keys[1:]): if idx == len(sub_keys) - 2: cur[sub_key] = value else: cur = cur[sub_key] + return config def check_gpu(use_gpu): @@ -204,20 +186,24 @@ def train(config, model_type = None algorithm = config['Architecture']['algorithm'] - if 'start_epoch' in best_model_dict: - start_epoch = best_model_dict['start_epoch'] - else: - start_epoch = 1 + start_epoch = best_model_dict[ + 'start_epoch'] if 'start_epoch' in best_model_dict else 1 + + train_reader_cost = 0.0 + train_run_cost = 0.0 + total_samples = 0 + reader_start = time.time() + + max_iter = len(train_dataloader) - 1 if platform.system( + ) == "Windows" else len(train_dataloader) for epoch in range(start_epoch, epoch_num + 1): - train_dataloader = build_dataloader( - config, 'Train', device, logger, seed=epoch) - train_reader_cost = 0.0 - train_run_cost = 0.0 - total_samples = 0 - reader_start = time.time() - max_iter = len(train_dataloader) - 1 if platform.system( - ) == "Windows" else len(train_dataloader) + if train_dataloader.dataset.need_reset: + train_dataloader = build_dataloader( + config, 'Train', device, logger, seed=epoch) + max_iter = len(train_dataloader) - 1 if platform.system( + ) == "Windows" else len(train_dataloader) + for idx, batch in enumerate(train_dataloader): profiler.add_profiler_step(profiler_options) train_reader_cost += time.time() - reader_start @@ -239,10 +225,11 @@ def train(config, else: if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) - elif model_type == "kie": + elif model_type in ["kie", 'vqa']: preds = model(batch) else: preds = model(images) + loss = loss_class(preds, batch) avg_loss = loss['loss'] @@ -256,6 +243,7 @@ def train(config, optimizer.clear_grad() train_run_cost += time.time() - train_start + global_step += 1 total_samples += len(images) if not isinstance(lr_scheduler, float): @@ -285,12 +273,13 @@ def train(config, (global_step > 0 and global_step % print_batch_step == 0) or (idx >= len(train_dataloader) - 1)): logs = train_stats.log() - strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format( + strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ips: {:.5f}'.format( epoch, epoch_num, global_step, logs, train_reader_cost / print_batch_step, (train_reader_cost + train_run_cost) / - print_batch_step, total_samples, + print_batch_step, total_samples / print_batch_step, total_samples / (train_reader_cost + train_run_cost)) logger.info(strs) + train_reader_cost = 0.0 train_run_cost = 0.0 total_samples = 0 @@ -330,6 +319,7 @@ def train(config, optimizer, save_model_dir, logger, + config, is_best=True, prefix='best_accuracy', best_model_dict=best_model_dict, @@ -344,8 +334,7 @@ def train(config, vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator), best_model_dict[main_indicator], global_step) - global_step += 1 - optimizer.clear_grad() + 
reader_start = time.time() if dist.get_rank() == 0: save_model( @@ -353,6 +342,7 @@ def train(config, optimizer, save_model_dir, logger, + config, is_best=False, prefix='latest', best_model_dict=best_model_dict, @@ -364,6 +354,7 @@ def train(config, optimizer, save_model_dir, logger, + config, is_best=False, prefix='iter_epoch_{}'.format(epoch), best_model_dict=best_model_dict, @@ -401,19 +392,28 @@ def eval(model, start = time.time() if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) - elif model_type == "kie": + elif model_type in ["kie", 'vqa']: preds = model(batch) else: preds = model(images) - batch = [item.numpy() for item in batch] + + batch_numpy = [] + for item in batch: + if isinstance(item, paddle.Tensor): + batch_numpy.append(item.numpy()) + else: + batch_numpy.append(item) # Obtain usable results from post-processing methods total_time += time.time() - start # Evaluate the results of the current batch if model_type in ['table', 'kie']: - eval_class(preds, batch) + eval_class(preds, batch_numpy) + elif model_type in ['vqa']: + post_result = post_process_class(preds, batch_numpy) + eval_class(post_result, batch_numpy) else: - post_result = post_process_class(preds, batch[1]) - eval_class(post_result, batch) + post_result = post_process_class(preds, batch_numpy[1]) + eval_class(post_result, batch_numpy) pbar.update(1) total_frame += len(images) @@ -479,9 +479,9 @@ def preprocess(is_train=False): FLAGS = ArgsParser().parse_args() profiler_options = FLAGS.profiler_options config = load_config(FLAGS.config) - merge_config(FLAGS.opt) + config = merge_config(config, FLAGS.opt) profile_dic = {"profiler_options": FLAGS.profiler_options} - merge_config(profile_dic) + config = merge_config(config, profile_dic) if is_train: # save_config @@ -503,13 +503,8 @@ def preprocess(is_train=False): assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', - 'SEED', 'SDMGR' + 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM' ] - windows_not_support_list = ['PSE'] - if platform.system() == "Windows" and alg in windows_not_support_list: - logger.warning('{} is not support in Windows now'.format( - windows_not_support_list)) - sys.exit() device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = paddle.set_device(device) diff --git a/tools/train.py b/tools/train.py index f3852469eb..506e0f7fa8 100755 --- a/tools/train.py +++ b/tools/train.py @@ -27,8 +27,6 @@ import yaml import paddle import paddle.distributed as dist -paddle.seed(2) - from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.losses import build_loss @@ -36,6 +34,7 @@ from ppocr.optimizer import build_optimizer from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric from ppocr.utils.save_load import load_model +from ppocr.utils.utility import set_seed import tools.program as program dist.get_world_size() @@ -97,7 +96,8 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) # load pretrain model - pre_best_model_dict = load_model(config, model, optimizer) + pre_best_model_dict = load_model(config, model, optimizer, + config['Architecture']["model_type"]) logger.info('train dataloader has {} iters'.format(len(train_dataloader))) if valid_dataloader is not None: logger.info('valid dataloader has {} iters'.format( @@ -145,5 +145,7 @@ def test_reader(config, device, logger): if 
__name__ == '__main__': config, device, logger, vdl_writer = program.preprocess(is_train=True) + seed = config['Global'].get('seed', 1024) + set_seed(seed) main(config, device, logger, vdl_writer) # test_reader(config, device, logger)
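
Note on make_input in tools/infer_vqa_token_ser_re.py above: relation candidates are built by pairing every entity the SER model tagged as QUESTION (label 1) with every entity tagged as ANSWER (label 2); the RE head then scores each candidate pair. A self-contained sketch of that enumeration on toy labels, not taken from the diff:

    # Toy SER predictions; 0/1/2 follow the entities_labels mapping in make_input
    # (HEADER=0, QUESTION=1, ANSWER=2). The label list itself is made up for illustration.
    labels = [0, 1, 2, 1, 2]
    head, tail = [], []
    for i, li in enumerate(labels):
        for j, lj in enumerate(labels):
            if li == 1 and lj == 2:  # propose every QUESTION -> ANSWER pair
                head.append(i)
                tail.append(j)
    print(list(zip(head, tail)))  # [(1, 2), (1, 4), (3, 2), (3, 4)]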
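
Note on the tools/program.py refactor above: load_config now returns a plain dict and merge_config explicitly folds override options into that dict and returns it, instead of mutating a module-level global_config. A minimal usage sketch, assuming the repository root is on sys.path and PaddleOCR's dependencies are installed; the config path and override keys are illustrative only:

    # Sketch only: exercise the refactored, side-effect-free config helpers.
    from tools.program import load_config, merge_config

    config = load_config('configs/vqa/ser/layoutxlm.yml')  # plain dict, no global state
    # Dotted keys walk into nested sections, mirroring -o command-line overrides.
    config = merge_config(config, {
        'Global.epoch_num': 10,
        'Optimizer.lr.learning_rate': 1e-5,
    })
    print(config['Global']['epoch_num'])  # 10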