diff --git a/docs/en/datasets/det.md b/docs/en/datasets/det.md index 65cfd84a..6f8887b2 100644 --- a/docs/en/datasets/det.md +++ b/docs/en/datasets/det.md @@ -59,6 +59,7 @@ The structure of the text detection dataset directory is organized as follows. | Totaltext | [homepage](https://github.com/cs-chan/Total-Text-Dataset) | - | - | - | | CurvedSynText150k | [homepage](https://github.com/aim-uofa/AdelaiDet/blob/master/datasets/README.md) \| [Part1](https://drive.google.com/file/d/1OSJ-zId2h3t_-I7g_wUkrK-VqQy153Kj/view?usp=sharing) \| [Part2](https://drive.google.com/file/d/1EzkcOlIgEp5wmEubvHb7-J5EImHExYgY/view?usp=sharing) | [instances_training.json](https://download.openmmlab.com/mmocr/data/curvedsyntext/instances_training.json) | - | - | | FUNSD | [homepage](https://guillaumejaume.github.io/FUNSD/) | - | - | - | +| SROIE | [homepage](https://rrc.cvc.uab.es/?ch=13) | - | - | - | | Lecture Video DB | [homepage](https://cvit.iiit.ac.in/research/projects/cvit-projects/lecturevideodb) | - | - | - | @@ -219,6 +220,45 @@ rm dataset.zip && rm -rf dataset python tools/data/textdet/funsd_converter.py PATH/TO/funsd --nproc 4 ``` +### SROIE + +- Step1: Download `0325updated.task1train(626p).zip`, `task1&2_test(361p).zip`, and `text.task1&2-test(361p).zip` from [homepage](https://rrc.cvc.uab.es/?ch=13&com=downloads) to `sroie/` + +- Step2: + + ```bash + mkdir sroie && cd sroie + mkdir imgs && mkdir annotations && mkdir imgs/training + + # Warning: The zip files downloaded from Google Drive and BaiduYun Cloud may + # be different; revise the following commands to use the correct file + # names if errors occur while extracting or moving the files. 
+ unzip -q 0325updated.task1train\(626p\).zip && unzip -q task1\&2_test\(361p\).zip && unzip -q text.task1\&2-test\(361p\).zip + + # For images + mv 0325updated.task1train\(626p\)/*.jpg imgs/training && mv fulltext_test\(361p\) imgs/test + + # For annotations + mv 0325updated.task1train\(626p\) annotations/training && mv text.task1\&2-testги361p\)/ annotations/test + + rm 0325updated.task1train\(626p\).zip && rm task1\&2_test\(361p\).zip && rm text.task1\&2-test\(361p\).zip + ``` + +- Step3: Generate `instances_training.json` and `instances_test.json` with the following command: + + ```bash + python tools/data/textdet/sroie_converter.py PATH/TO/sroie --nproc 4 + ``` + +- After running the above commands, the directory structure should be as follows: + + ```text + ├── sroie + │   ├── annotations + │   ├── imgs + │   ├── instances_test.json + │   └── instances_training.json + ``` ### Lecture Video DB - Step1: Download [IIIT-CVid.zip](http://cdn.iiit.ac.in/cdn/preon.iiit.ac.in/~kartik/IIIT-CVid.zip) to `lv/`. 
diff --git a/docs/en/datasets/recog.md b/docs/en/datasets/recog.md index 111406d2..14db0a7b 100644 --- a/docs/en/datasets/recog.md +++ b/docs/en/datasets/recog.md @@ -103,6 +103,7 @@ | Totaltext | [homepage](https://github.com/cs-chan/Total-Text-Dataset) | - | - | | | OpenVINO | [Open Images](https://github.com/cvdfoundation/open-images-dataset) | [annotations](https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/datasets/open_images_v5_text) | [annotations](https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/datasets/open_images_v5_text) | | | FUNSD | [homepage](https://guillaumejaume.github.io/FUNSD/) | - | - | | +| SROIE | [homepage](https://rrc.cvc.uab.es/?ch=13) | - | - | - | | Lecture Video DB | [homepage](https://cvit.iiit.ac.in/research/projects/cvit-projects/lecturevideodb) | - | - | - | @@ -327,6 +328,44 @@ rm dataset.zip && rm -rf dataset python tools/data/textrecog/funsd_converter.py PATH/TO/funsd --nproc 4 ``` +### SROIE + +- Step1: Download `0325updated.task1train(626p).zip`, `task1&2_test(361p).zip`, and `text.task1&2-test(361p).zip` from [homepage](https://rrc.cvc.uab.es/?ch=13&com=downloads) to `sroie/` + +- Step2: + + ```bash + mkdir sroie && cd sroie + mkdir imgs && mkdir annotations && mkdir imgs/training + + # Warning: The zip files downloaded from Google Drive and BaiduYun Cloud may + # be different; revise the following commands to use the correct file + # names if errors occur while extracting or moving the files. 
+ unzip -q 0325updated.task1train\(626p\).zip && unzip -q task1\&2_test\(361p\).zip && unzip -q text.task1\&2-test\(361p\).zip + + # For images + mv 0325updated.task1train\(626p\)/*.jpg imgs/training && mv fulltext_test\(361p\) imgs/test + + # For annotations + mv 0325updated.task1train\(626p\) annotations/training && mv text.task1\&2-testги361p\)/ annotations/test + + rm 0325updated.task1train\(626p\).zip && rm task1\&2_test\(361p\).zip && rm text.task1\&2-test\(361p\).zip + ``` + +- Step3: Generate `train_label.jsonl` and `test_label.jsonl` and crop images using 4 processes with the following command: + + ```bash + python tools/data/textrecog/sroie_converter.py PATH/TO/sroie --nproc 4 + ``` + +- After running the above commands, the directory structure should be as follows: + + ```text + ├── sroie + │ ├── crops + │ ├── train_label.jsonl + │ ├── test_label.jsonl + ``` ### Lecture Video DB **The LV dataset has already provided cropped images and the corresponding annotations** diff --git a/tools/data/textdet/sroie_converter.py b/tools/data/textdet/sroie_converter.py new file mode 100644 index 00000000..c5ef1294 --- /dev/null +++ b/tools/data/textdet/sroie_converter.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp + +import mmcv +import numpy as np + +from mmocr.utils import convert_annotations + + +def collect_files(img_dir, gt_dir): + """Collect all images and their corresponding groundtruth files. 
+ + Args: + img_dir (str): The image directory + gt_dir (str): The groundtruth directory + + Returns: + files (list): The list of tuples (img_file, groundtruth_file) + """ + + assert isinstance(img_dir, str) + assert img_dir + assert isinstance(gt_dir, str) + assert gt_dir + + ann_list, imgs_list = [], [] + for gt_file in os.listdir(gt_dir): + # Filtering repeated and missing images + if '(' in gt_file or gt_file == 'X51006619570.txt': + continue + ann_list.append(osp.join(gt_dir, gt_file)) + imgs_list.append(osp.join(img_dir, gt_file.replace('.txt', '.jpg'))) + + files = list(zip(sorted(imgs_list), sorted(ann_list))) + assert len(files), f'No images found in {img_dir}' + + print(f'Loaded {len(files)} images from {img_dir}') + + return files + + +def collect_annotations(files, nproc=1): + """Collect the annotation information. + + Args: + files (list): The list of tuples (image_file, groundtruth_file) + nproc (int): The number of processes used to collect annotations + + Returns: + images (list): The list of image information dicts + """ + + assert isinstance(files, list) + assert isinstance(nproc, int) + + if nproc > 1: + images = mmcv.track_parallel_progress( + load_img_info, files, nproc=nproc) + else: + images = mmcv.track_progress(load_img_info, files) + + return images + + +def load_img_info(files): + """Load the information of one image. 
+ + Args: + files (tuple): The tuple of (img_file, groundtruth_file) + + Returns: + img_info (dict): The dict of the img and annotation information + """ + + assert isinstance(files, tuple) + + img_file, gt_file = files + assert osp.basename(gt_file).split('.')[0] == osp.basename(img_file).split( + '.')[0] + # read imgs while ignoring orientations + img = mmcv.imread(img_file, 'unchanged') + + img_info = dict( + file_name=osp.basename(img_file), + height=img.shape[0], + width=img.shape[1], + segm_file=osp.basename(gt_file)) + + if osp.splitext(gt_file)[1] == '.txt': + img_info = load_txt_info(gt_file, img_info) + else: + raise NotImplementedError + + return img_info + + +def load_txt_info(gt_file, img_info): + """Collect the annotation information. + + Args: + gt_file (str): The path to the groundtruth file + img_info (dict): The dict of the img and annotation information + + Returns: + img_info (dict): The dict of the img and annotation information + """ + + with open(gt_file, 'r', encoding='unicode_escape') as f: + anno_info = [] + for ann in f.readlines(): + + # annotation format [x1, y1, x2, y2, x3, y3, x4, y4, transcript] + try: + ann_box = np.array(ann.split(',')[0:8]).astype(int).tolist() + except ValueError: + # skip invalid annotation line + continue + x = max(0, min(ann_box[0::2])) + y = max(0, min(ann_box[1::2])) + w, h = max(ann_box[0::2]) - x, max(ann_box[1::2]) - y + bbox = [x, y, w, h] + segmentation = ann_box + + anno = dict( + iscrowd=0, + category_id=1, + bbox=bbox, + area=w * h, + segmentation=[segmentation]) + anno_info.append(anno) + + img_info.update(anno_info=anno_info) + + return img_info + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Generate training and test set of SROIE') + parser.add_argument('root_path', help='Root dir path of SROIE') + parser.add_argument( + '--nproc', default=1, type=int, help='Number of process') + args = parser.parse_args() + return args + + +def 
main(): + args = parse_args() + root_path = args.root_path + + for split in ['training', 'test']: + print(f'Processing {split} set...') + with mmcv.Timer(print_tmpl='It takes {}s to convert SROIE annotation'): + files = collect_files( + osp.join(root_path, 'imgs', split), + osp.join(root_path, 'annotations', split)) + image_infos = collect_annotations(files, nproc=args.nproc) + convert_annotations( + image_infos, osp.join(root_path, + 'instances_' + split + '.json')) + + +if __name__ == '__main__': + main() diff --git a/tools/data/textrecog/sroie_converter.py b/tools/data/textrecog/sroie_converter.py new file mode 100644 index 00000000..3f0ae477 --- /dev/null +++ b/tools/data/textrecog/sroie_converter.py @@ -0,0 +1,219 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import json +import os +import os.path as osp + +import mmcv +import numpy as np + +from mmocr.datasets.pipelines.crop import crop_img +from mmocr.utils.fileio import list_to_file + + +def collect_files(img_dir, gt_dir): + """Collect all images and their corresponding groundtruth files. + + Args: + img_dir (str): The image directory + gt_dir (str): The groundtruth directory + + Returns: + files (list): The list of tuples (img_file, groundtruth_file) + """ + + assert isinstance(img_dir, str) + assert img_dir + assert isinstance(gt_dir, str) + assert gt_dir + + ann_list, imgs_list = [], [] + for gt_file in os.listdir(gt_dir): + # Filtering repeated and missing images + if '(' in gt_file or gt_file == 'X51006619570.txt': + continue + ann_list.append(osp.join(gt_dir, gt_file)) + imgs_list.append(osp.join(img_dir, gt_file.replace('.txt', '.jpg'))) + + files = list(zip(sorted(imgs_list), sorted(ann_list))) + assert len(files), f'No images found in {img_dir}' + + print(f'Loaded {len(files)} images from {img_dir}') + + return files + + +def collect_annotations(files, nproc=1): + """Collect the annotation information. 
+ + Args: + files (list): The list of tuples (image_file, groundtruth_file) + nproc (int): The number of processes used to collect annotations + + Returns: + images (list): The list of image information dicts + """ + + assert isinstance(files, list) + assert isinstance(nproc, int) + + if nproc > 1: + images = mmcv.track_parallel_progress( + load_img_info, files, nproc=nproc) + else: + images = mmcv.track_progress(load_img_info, files) + + return images + + +def load_img_info(files): + """Load the information of one image. + + Args: + files (tuple): The tuple of (img_file, groundtruth_file) + + Returns: + img_info (dict): The dict of the img and annotation information + """ + + assert isinstance(files, tuple) + + img_file, gt_file = files + assert osp.basename(gt_file).split('.')[0] == osp.basename(img_file).split( + '.')[0] + # read imgs while ignoring orientations + img = mmcv.imread(img_file, 'unchanged') + + img_info = dict( + file_name=osp.basename(img_file), + height=img.shape[0], + width=img.shape[1], + segm_file=osp.basename(gt_file)) + + if osp.splitext(gt_file)[1] == '.txt': + img_info = load_txt_info(gt_file, img_info) + else: + raise NotImplementedError + + return img_info + + +def load_txt_info(gt_file, img_info): + """Collect the annotation information. 
+ + Annotation Format + x1, y1, x2, y2, x3, y3, x4, y4, transcript + + Args: + gt_file (str): The path to the groundtruth file + img_info (dict): The dict of the img and annotation information + + Returns: + img_info (dict): The dict of the img and annotation information + """ + + with open(gt_file, 'r', encoding='unicode_escape') as f: + anno_info = [] + for ann in f.readlines(): + # the first 8 fields are the box; the rest is the transcript + try: + bbox = np.array(ann.split(',')[0:8]).astype(int).tolist() + except ValueError: + # skip invalid annotation line + continue + word = ','.join(ann.split(',')[8:]).replace('\n', '').strip() + + anno = dict(bbox=bbox, word=word) + anno_info.append(anno) + + img_info.update(anno_info=anno_info) + + return img_info + + +def generate_ann(root_path, split, image_infos, format): + """Generate cropped annotations and label txt file. + + Args: + root_path (str): The root path of the dataset + split (str): The split of dataset. Namely: training or test + image_infos (list[dict]): A list of dicts of the img and + annotation information + format (str): Annotation format, should be either 'jsonl' or 'txt' + """ + + dst_image_root = osp.join(root_path, 'crops', split) + if split == 'training': + dst_label_file = osp.join(root_path, f'train_label.{format}') + elif split == 'test': + dst_label_file = osp.join(root_path, f'test_label.{format}') + os.makedirs(dst_image_root, exist_ok=True) + + lines = [] + for image_info in image_infos: + index = 1 + src_img_path = osp.join(root_path, 'imgs', split, + image_info['file_name']) + image = mmcv.imread(src_img_path) + src_img_root = image_info['file_name'].split('.')[0] + + for anno in image_info['anno_info']: + word = anno['word'] + dst_img = crop_img(image, anno['bbox'], 0, 0) + + # Skip invalid annotations + if min(dst_img.shape) == 0 or len(word) == 0: + continue + + dst_img_name = f'{src_img_root}_{index}.png' + index += 1 + dst_img_path = osp.join(dst_image_root, dst_img_name) + mmcv.imwrite(dst_img, dst_img_path) + + if format == 'txt': + 
lines.append(f'{osp.basename(dst_image_root)}/{dst_img_name} ' + f'{word}') + elif format == 'jsonl': + lines.append( + json.dumps({ + 'filename': + f'{osp.basename(dst_image_root)}/{dst_img_name}', + 'text': word + })) + else: + raise NotImplementedError + + list_to_file(dst_label_file, lines) + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Generate training and test set of SROIE') + parser.add_argument('root_path', help='Root dir path of SROIE') + parser.add_argument( + '--nproc', default=1, type=int, help='Number of process') + parser.add_argument( + '--format', + default='jsonl', + help='Use jsonl or string to format annotations', + choices=['jsonl', 'txt']) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + root_path = args.root_path + + for split in ['training', 'test']: + print(f'Processing {split} set...') + with mmcv.Timer(print_tmpl='It takes {}s to convert SROIE annotation'): + files = collect_files( + osp.join(root_path, 'imgs', split), + osp.join(root_path, 'annotations', split)) + image_infos = collect_annotations(files, nproc=args.nproc) + generate_ann(root_path, split, image_infos, args.format) + + +if __name__ == '__main__': + main()