From 14c75da7bd7a31a5dd013ca43c58f80985ddd417 Mon Sep 17 00:00:00 2001 From: Xinyu Wang <45810070+xinke-wang@users.noreply.github.com> Date: Fri, 4 Mar 2022 14:55:54 +1030 Subject: [PATCH] [Feature] Add FUNSD Converter (#808) * Add FUNSD Converter * Update tools/data/textrecog/funsd_converter.py Co-authored-by: Tong Gao * Update tools/data/textrecog/funsd_converter.py Co-authored-by: Tong Gao * Update tools/data/textdet/funsd_converter.py Co-authored-by: Tong Gao * blank line between sections Co-authored-by: Tong Gao * fix incorrect docstrings * fix docstrings & fix timer * add --preserve-vertical arg for preserving vertical texts * fix --preserve-vertical * [doc] fix recog.md incorrect description * fix docstring style Co-authored-by: Tong Gao * fix docstring spaces Co-authored-by: Tong Gao --- docs/en/datasets/det.md | 54 +++++-- docs/en/datasets/recog.md | 69 ++++++-- tools/data/textdet/funsd_converter.py | 157 ++++++++++++++++++ tools/data/textrecog/funsd_converter.py | 203 ++++++++++++++++++++++++ 4 files changed, 456 insertions(+), 27 deletions(-) create mode 100644 tools/data/textdet/funsd_converter.py create mode 100644 tools/data/textrecog/funsd_converter.py diff --git a/docs/en/datasets/det.md b/docs/en/datasets/det.md index 75637e8d..93d4fdd3 100644 --- a/docs/en/datasets/det.md +++ b/docs/en/datasets/det.md @@ -36,18 +36,25 @@ The structure of the text detection dataset directory is organized as follows. │   ├── syntext_word_eng │   ├── emcs_imgs │   └── instances_training.json +|── funsd +|   ├── annotations +│   ├── imgs +│   ├── instances_test.json +│   └── instances_training.json ``` -|Dataset|Images| | Annotation Files | | | -| :-------: | :------------------------------------------------------------: | :----------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------: | :-------------------------------------: | :--------------------------------------------------------------------------------------------: | -| | | training | validation | testing | | -| CTW1500 | [homepage](https://github.com/Yuliang-Liu/Curve-Text-Detector) | - | - | - | -| ICDAR2015 | [homepage](https://rrc.cvc.uab.es/?ch=4&com=downloads) | [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_training.json) | - | [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_test.json) | -| ICDAR2017 | [homepage](https://rrc.cvc.uab.es/?ch=8&com=downloads) | [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2017/instances_training.json) | [instances_val.json](https://download.openmmlab.com/mmocr/data/icdar2017/instances_val.json) | - | | | -| Synthtext | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | instances_training.lmdb ([data.mdb](https://download.openmmlab.com/mmocr/data/synthtext/instances_training.lmdb/data.mdb), [lock.mdb](https://download.openmmlab.com/mmocr/data/synthtext/instances_training.lmdb/lock.mdb)) | - | - | -| TextOCR | [homepage](https://textvqa.org/textocr/dataset) | - | - | - -| Totaltext | [homepage](https://github.com/cs-chan/Total-Text-Dataset) | - | - | - -| CurvedSynText150k | [homepage](https://github.com/aim-uofa/AdelaiDet/blob/master/datasets/README.md) \| [Part1](https://drive.google.com/file/d/1OSJ-zId2h3t_-I7g_wUkrK-VqQy153Kj/view?usp=sharing) \| [Part2](https://drive.google.com/file/d/1EzkcOlIgEp5wmEubvHb7-J5EImHExYgY/view?usp=sharing) | [instances_training.json](https://download.openmmlab.com/mmocr/data/curvedsyntext/instances_training.json) | - | - | +| Dataset | Images | | Annotation Files | | | +| :---------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------: | :---: | +| | | training | validation | testing | | +| CTW1500 | [homepage](https://github.com/Yuliang-Liu/Curve-Text-Detector) | - | - | - | +| ICDAR2015 | [homepage](https://rrc.cvc.uab.es/?ch=4&com=downloads) | [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_training.json) | - | [instances_test.json](https://download.openmmlab.com/mmocr/data/icdar2015/instances_test.json) | +| ICDAR2017 | [homepage](https://rrc.cvc.uab.es/?ch=8&com=downloads) | [instances_training.json](https://download.openmmlab.com/mmocr/data/icdar2017/instances_training.json) | [instances_val.json](https://download.openmmlab.com/mmocr/data/icdar2017/instances_val.json) | - | | | +| Synthtext | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | instances_training.lmdb ([data.mdb](https://download.openmmlab.com/mmocr/data/synthtext/instances_training.lmdb/data.mdb), [lock.mdb](https://download.openmmlab.com/mmocr/data/synthtext/instances_training.lmdb/lock.mdb)) | - | - | +| TextOCR | [homepage](https://textvqa.org/textocr/dataset) | - | - | - | +| Totaltext | [homepage](https://github.com/cs-chan/Total-Text-Dataset) | - | - | - | +| CurvedSynText150k | [homepage](https://github.com/aim-uofa/AdelaiDet/blob/master/datasets/README.md) \| [Part1](https://drive.google.com/file/d/1OSJ-zId2h3t_-I7g_wUkrK-VqQy153Kj/view?usp=sharing) \| [Part2](https://drive.google.com/file/d/1EzkcOlIgEp5wmEubvHb7-J5EImHExYgY/view?usp=sharing) | [instances_training.json](https://download.openmmlab.com/mmocr/data/curvedsyntext/instances_training.json) | - | - | +| FUNSD | [homepage](https://guillaumejaume.github.io/FUNSD/) | - | - | - | + ## Important Note @@ -178,3 +185,30 @@ rm images.zip ```bash python tools/data/common/curvedsyntext_converter.py PATH/TO/CurvedSynText150k --nproc 4 ``` + +### FUNSD + +- Step1: Download [dataset.zip](https://guillaumejaume.github.io/FUNSD/dataset.zip) to `funsd/`. + +```bash +mkdir funsd && cd funsd + +# Download FUNSD dataset +wget https://guillaumejaume.github.io/FUNSD/dataset.zip +unzip -q dataset.zip + +# For images +mv dataset/training_data/images imgs && mv dataset/testing_data/images/* imgs/ + +# For annotations +mkdir annotations +mv dataset/training_data/annotations annotations/training && mv dataset/testing_data/annotations annotations/test + +rm dataset.zip && rm -rf dataset +``` + +- Step2: Generate `instances_training.json` and `instances_test.json` with following command: + +```bash +python tools/data/textdet/funsd_converter.py PATH/TO/funsd --nproc 4 +``` diff --git a/docs/en/datasets/recog.md b/docs/en/datasets/recog.md index 0b3aee47..47d1e18c 100644 --- a/docs/en/datasets/recog.md +++ b/docs/en/datasets/recog.md @@ -73,25 +73,33 @@ │ │ ├── train_5_label.txt │ │ ├── train_f_label.txt │ │ ├── val_label.txt +│   ├── funsd +│ │ ├── imgs +│ │ ├── dst_imgs +│ │ ├── annotations +│ │ ├── train_label.txt +│ │ ├── test_label.txt ``` -| Dataset | images | annotation file | annotation file | -| :--------: | :-----------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------: | -| | | training | test | -| coco_text | [homepage](https://rrc.cvc.uab.es/?ch=5&com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/coco_text/train_label.txt) | - | | -| icdar_2011 | [homepage](http://www.cvc.uab.es/icdar2011competition/?com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/train_label.txt) | - | | -| icdar_2013 | [homepage](https://rrc.cvc.uab.es/?ch=2&com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/train_label.txt) | [test_label_1015.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/test_label_1015.txt) | | -| icdar_2015 | [homepage](https://rrc.cvc.uab.es/?ch=4&com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/train_label.txt) | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/test_label.txt) | | -| IIIT5K | [homepage](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/train_label.txt) | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/test_label.txt) | | -| ct80 | [homepage](http://cs-chan.com/downloads_CUTE80_dataset.html) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/ct80/test_label.txt) | | -| svt |[homepage](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svt/test_label.txt) | | -| svtp | [unofficial homepage\[1\]](https://github.com/Jyouhou/Case-Sensitive-Scene-Text-Recognition-Datasets) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svtp/test_label.txt) | | -| MJSynth (Syn90k) | [homepage](https://www.robots.ox.ac.uk/~vgg/data/text/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Syn90k/shuffle_labels.txt) \| [label.txt](https://download.openmmlab.com/mmocr/data/mixture/Syn90k/label.txt) | - | | -| SynthText (Synth800k) | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | [alphanumeric_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/alphanumeric_labels.txt) \|[shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) \| [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) \| [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/label.txt) | - | | -| SynthAdd | [SynthText_Add.zip](https://pan.baidu.com/s/1uV0LtoNmcxbO-0YA7Ch4dg) (code:627x) | [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt) | - | | -| TextOCR | [homepage](https://textvqa.org/textocr/dataset) | - | - | | -| Totaltext | [homepage](https://github.com/cs-chan/Total-Text-Dataset) | - | - | | -| OpenVINO | [Open Images](https://github.com/cvdfoundation/open-images-dataset) | [annotations](https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/datasets/open_images_v5_text) |[annotations](https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/datasets/open_images_v5_text)| | +| Dataset | images | annotation file | annotation file | +| :-------------------: | :---------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------: | +| | | training | test | +| coco_text | [homepage](https://rrc.cvc.uab.es/?ch=5&com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/coco_text/train_label.txt) | - | | +| icdar_2011 | [homepage](http://www.cvc.uab.es/icdar2011competition/?com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/train_label.txt) | - | | +| icdar_2013 | [homepage](https://rrc.cvc.uab.es/?ch=2&com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/train_label.txt) | [test_label_1015.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2013/test_label_1015.txt) | | +| icdar_2015 | [homepage](https://rrc.cvc.uab.es/?ch=4&com=downloads) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/train_label.txt) | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/icdar_2015/test_label.txt) | | +| IIIT5K | [homepage](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html) | [train_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/train_label.txt) | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/IIIT5K/test_label.txt) | | +| ct80 | [homepage](http://cs-chan.com/downloads_CUTE80_dataset.html) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/ct80/test_label.txt) | | +| svt | [homepage](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svt/test_label.txt) | | +| svtp | [unofficial homepage\[1\]](https://github.com/Jyouhou/Case-Sensitive-Scene-Text-Recognition-Datasets) | - | [test_label.txt](https://download.openmmlab.com/mmocr/data/mixture/svtp/test_label.txt) | | +| MJSynth (Syn90k) | [homepage](https://www.robots.ox.ac.uk/~vgg/data/text/) | [shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/Syn90k/shuffle_labels.txt) \| [label.txt](https://download.openmmlab.com/mmocr/data/mixture/Syn90k/label.txt) | - | | +| SynthText (Synth800k) | [homepage](https://www.robots.ox.ac.uk/~vgg/data/scenetext/) | [alphanumeric_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/alphanumeric_labels.txt) \|[shuffle_labels.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/shuffle_labels.txt) \| [instances_train.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/instances_train.txt) \| [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthText/label.txt) | - | | +| SynthAdd | [SynthText_Add.zip](https://pan.baidu.com/s/1uV0LtoNmcxbO-0YA7Ch4dg) (code:627x) | [label.txt](https://download.openmmlab.com/mmocr/data/mixture/SynthAdd/label.txt) | - | | +| TextOCR | [homepage](https://textvqa.org/textocr/dataset) | - | - | | +| Totaltext | [homepage](https://github.com/cs-chan/Total-Text-Dataset) | - | - | | +| OpenVINO | [Open Images](https://github.com/cvdfoundation/open-images-dataset) | [annotations](https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/datasets/open_images_v5_text) | [annotations](https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/datasets/open_images_v5_text) | | +| FUNSD | [homepage](https://guillaumejaume.github.io/FUNSD/) | - | - | | + (*) Since the official homepage is unavailable now, we provide an alternative for quick reference. However, we do not guarantee the correctness of the dataset. @@ -286,3 +294,30 @@ python tools/data/utils/txt2lmdb.py -i data/mixture/Syn90k/label.txt -o data/mix ```bash python tools/data/textrecog/openvino_converter.py /path/to/openvino 4 ``` + +### FUNSD + +- Step1: Download [dataset.zip](https://guillaumejaume.github.io/FUNSD/dataset.zip) to `funsd/`. + +```bash +mkdir funsd && cd funsd + +# Download FUNSD dataset +wget https://guillaumejaume.github.io/FUNSD/dataset.zip +unzip -q dataset.zip + +# For images +mv dataset/training_data/images imgs && mv dataset/testing_data/images/* imgs/ + +# For annotations +mkdir annotations +mv dataset/training_data/annotations annotations/training && mv dataset/testing_data/annotations annotations/test + +rm dataset.zip && rm -rf dataset +``` + +- Step2: Generate `train_label.txt` and `test_label.txt` and crop images using 4 processes with following command (add `--preserve-vertical` if you wish to preserve the images containing vertical texts): + +```bash +python tools/data/textrecog/funsd_converter.py PATH/TO/funsd --nproc 4 +``` diff --git a/tools/data/textdet/funsd_converter.py b/tools/data/textdet/funsd_converter.py new file mode 100644 index 00000000..6e3cf5dc --- /dev/null +++ b/tools/data/textdet/funsd_converter.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import math +import os +import os.path as osp + +import mmcv + +from mmocr.utils import convert_annotations + + +def collect_files(img_dir, gt_dir): + """Collect all images and their corresponding groundtruth files. + + Args: + img_dir (str): The image directory + gt_dir (str): The groundtruth directory + + Returns: + files (list): The list of tuples (img_file, groundtruth_file) + """ + assert isinstance(img_dir, str) + assert img_dir + assert isinstance(gt_dir, str) + assert gt_dir + + ann_list, imgs_list = [], [] + for gt_file in os.listdir(gt_dir): + ann_list.append(osp.join(gt_dir, gt_file)) + imgs_list.append(osp.join(img_dir, gt_file.replace('.json', '.png'))) + + files = list(zip(sorted(imgs_list), sorted(ann_list))) + assert len(files), f'No images found in {img_dir}' + print(f'Loaded {len(files)} images from {img_dir}') + + return files + + +def collect_annotations(files, nproc=1): + """Collect the annotation information. + + Args: + files (list): The list of tuples (image_file, groundtruth_file) + nproc (int): The number of process to collect annotations + + Returns: + images (list): The list of image information dicts + """ + assert isinstance(files, list) + assert isinstance(nproc, int) + + if nproc > 1: + images = mmcv.track_parallel_progress( + load_img_info, files, nproc=nproc) + else: + images = mmcv.track_progress(load_img_info, files) + + return images + + +def load_img_info(files): + """Load the information of one image. + + Args: + files (tuple): The tuple of (img_file, groundtruth_file) + + Returns: + img_info (dict): The dict of the img and annotation information + """ + assert isinstance(files, tuple) + + img_file, gt_file = files + assert osp.basename(gt_file).split('.')[0] == osp.basename(img_file).split( + '.')[0] + # read imgs while ignoring orientations + img = mmcv.imread(img_file, 'unchanged') + + img_info = dict( + file_name=osp.join(osp.basename(img_file)), + height=img.shape[0], + width=img.shape[1], + segm_file=osp.join(osp.basename(gt_file))) + + if osp.splitext(gt_file)[1] == '.json': + img_info = load_json_info(gt_file, img_info) + else: + raise NotImplementedError + + return img_info + + +def load_json_info(gt_file, img_info): + """Collect the annotation information. + + Args: + gt_file (str): The path to ground-truth + img_info (dict): The dict of the img and annotation information + + Returns: + img_info (dict): The dict of the img and annotation information + """ + + annotation = mmcv.load(gt_file) + anno_info = [] + for form in annotation['form']: + for ann in form['words']: + + iscrowd = 1 if len(ann['text']) == 0 else 0 + + x1, y1, x2, y2 = ann['box'] + x = max(0, min(math.floor(x1), math.floor(x2))) + y = max(0, min(math.floor(y1), math.floor(y2))) + w, h = math.ceil(abs(x2 - x1)), math.ceil(abs(y2 - y1)) + bbox = [x, y, w, h] + segmentation = [x, y, x + w, y, x + w, y + h, x, y + h] + + anno = dict( + iscrowd=iscrowd, + category_id=1, + bbox=bbox, + area=w * h, + segmentation=[segmentation]) + anno_info.append(anno) + + img_info.update(anno_info=anno_info) + + return img_info + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Generate training and test set of FUNSD ') + parser.add_argument('root_path', help='Root dir path of FUNSD') + parser.add_argument( + '--nproc', default=1, type=int, help='number of process') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + root_path = args.root_path + + for split in ['training', 'test']: + print(f'Processing {split} set...') + with mmcv.Timer(print_tmpl='It takes {}s to convert FUNSD annotation'): + files = collect_files( + osp.join(root_path, 'imgs'), + osp.join(root_path, 'annotations', split)) + image_infos = collect_annotations(files, nproc=args.nproc) + convert_annotations( + image_infos, osp.join(root_path, + 'instances_' + split + '.json')) + + +if __name__ == '__main__': + main() diff --git a/tools/data/textrecog/funsd_converter.py b/tools/data/textrecog/funsd_converter.py new file mode 100644 index 00000000..28311120 --- /dev/null +++ b/tools/data/textrecog/funsd_converter.py @@ -0,0 +1,203 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import math +import os +import os.path as osp + +import mmcv + +from mmocr.datasets.pipelines.crop import crop_img +from mmocr.utils.fileio import list_to_file + + +def collect_files(img_dir, gt_dir): + """Collect all images and their corresponding groundtruth files. + + Args: + img_dir (str): The image directory + gt_dir (str): The groundtruth directory + + Returns: + files (list): The list of tuples (img_file, groundtruth_file) + """ + assert isinstance(img_dir, str) + assert img_dir + assert isinstance(gt_dir, str) + assert gt_dir + + ann_list, imgs_list = [], [] + for gt_file in os.listdir(gt_dir): + ann_list.append(osp.join(gt_dir, gt_file)) + imgs_list.append(osp.join(img_dir, gt_file.replace('.json', '.png'))) + + files = list(zip(sorted(imgs_list), sorted(ann_list))) + assert len(files), f'No images found in {img_dir}' + print(f'Loaded {len(files)} images from {img_dir}') + + return files + + +def collect_annotations(files, nproc=1): + """Collect the annotation information. + + Args: + files (list): The list of tuples (image_file, groundtruth_file) + nproc (int): The number of process to collect annotations + + Returns: + images (list): The list of image information dicts + """ + assert isinstance(files, list) + assert isinstance(nproc, int) + + if nproc > 1: + images = mmcv.track_parallel_progress( + load_img_info, files, nproc=nproc) + else: + images = mmcv.track_progress(load_img_info, files) + + return images + + +def load_img_info(files): + """Load the information of one image. + + Args: + files (tuple): The tuple of (img_file, groundtruth_file) + + Returns: + img_info (dict): The dict of the img and annotation information + """ + assert isinstance(files, tuple) + + img_file, gt_file = files + assert osp.basename(gt_file).split('.')[0] == osp.basename(img_file).split( + '.')[0] + # read imgs while ignoring orientations + img = mmcv.imread(img_file, 'unchanged') + + img_info = dict( + file_name=osp.join(osp.basename(img_file)), + height=img.shape[0], + width=img.shape[1], + segm_file=osp.join(osp.basename(gt_file))) + + if osp.splitext(gt_file)[1] == '.json': + img_info = load_json_info(gt_file, img_info) + else: + raise NotImplementedError + + return img_info + + +def load_json_info(gt_file, img_info): + """Collect the annotation information. + + Args: + gt_file (str): The path to ground-truth + img_info (dict): The dict of the img and annotation information + + Returns: + img_info (dict): The dict of the img and annotation information + """ + + annotation = mmcv.load(gt_file) + anno_info = [] + for form in annotation['form']: + for ann in form['words']: + + # Ignore illegible samples + if len(ann['text']) == 0: + continue + + x1, y1, x2, y2 = ann['box'] + x = max(0, min(math.floor(x1), math.floor(x2))) + y = max(0, min(math.floor(y1), math.floor(y2))) + w, h = math.ceil(abs(x2 - x1)), math.ceil(abs(y2 - y1)) + bbox = [x, y, x + w, y, x + w, y + h, x, y + h] + word = ann['text'] + + anno = dict(bbox=bbox, word=word) + anno_info.append(anno) + + img_info.update(anno_info=anno_info) + + return img_info + + +def generate_ann(root_path, split, image_infos, preserve_vertical): + """Generate cropped annotations and label txt file. + + Args: + root_path (str): The root path of the dataset + split (str): The split of dataset. Namely: training or test + image_infos (list[dict]): A list of dicts of the img and + annotation information + preserve_vertical (bool): Whether to preserve vertical texts + """ + + dst_image_root = osp.join(root_path, 'dst_imgs', split) + if split == 'training': + dst_label_file = osp.join(root_path, 'train_label.txt') + elif split == 'test': + dst_label_file = osp.join(root_path, 'test_label.txt') + os.makedirs(dst_image_root, exist_ok=True) + + lines = [] + for image_info in image_infos: + index = 1 + src_img_path = osp.join(root_path, 'imgs', image_info['file_name']) + image = mmcv.imread(src_img_path) + src_img_root = image_info['file_name'].split('.')[0] + + for anno in image_info['anno_info']: + word = anno['word'] + dst_img = crop_img(image, anno['bbox']) + h, w, _ = dst_img.shape + + # Skip invalid annotations + if min(dst_img.shape) == 0: + continue + # Skip vertical texts + if not preserve_vertical and h / w > 2: + continue + + dst_img_name = f'{src_img_root}_{index}.png' + index += 1 + dst_img_path = osp.join(dst_image_root, dst_img_name) + mmcv.imwrite(dst_img, dst_img_path) + lines.append(f'{osp.basename(dst_image_root)}/{dst_img_name} ' + f'{word}') + list_to_file(dst_label_file, lines) + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Generate training and test set of FUNSD ') + parser.add_argument('root_path', help='Root dir path of FUNSD') + parser.add_argument( + '--preserve-vertical', + help='Preserve samples containing vertical texts', + action='store_true') + parser.add_argument( + '--nproc', default=1, type=int, help='Number of processes') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + root_path = args.root_path + + for split in ['training', 'test']: + print(f'Processing {split} set...') + with mmcv.Timer(print_tmpl='It takes {}s to convert FUNSD annotation'): + files = collect_files( + osp.join(root_path, 'imgs'), + osp.join(root_path, 'annotations', split)) + image_infos = collect_annotations(files, nproc=args.nproc) + generate_ann(root_path, split, image_infos, args.preserve_vertical) + + +if __name__ == '__main__': + main()