From 75c06d34bbc01d3d11dfd7afc098b6cdeee82579 Mon Sep 17 00:00:00 2001
From: Qing Jiang <mountchicken@outlook.com>
Date: Wed, 8 Mar 2023 17:32:00 +0800
Subject: [PATCH] [Dataset Preparer] Add SCUT-CTW1500 (#1677)

* update metafile and download

* update parser

* updata ctw1500 to new dataprepare  design

* add lexicon into ctw1500 textspotting

* fix

---------

Co-authored-by: liukuikun <641417025@qq.com>
Co-authored-by: gaotongxiao <gaotongxiao@gmail.com>
---
 dataset_zoo/ctw1500/metafile.yml              |  32 ++++++
 dataset_zoo/ctw1500/textdet.py                |  76 +++++++++++++
 dataset_zoo/ctw1500/textrecog.py              |   9 ++
 dataset_zoo/ctw1500/textspotting.py           |  19 ++++
 mmocr/datasets/preparers/parsers/__init__.py  |   3 +-
 .../preparers/parsers/ctw1500_parser.py       | 102 ++++++++++++++++++
 .../test_parsers/test_ctw1500_parser.py       |  72 +++++++++++++
 7 files changed, 312 insertions(+), 1 deletion(-)
 create mode 100644 dataset_zoo/ctw1500/metafile.yml
 create mode 100644 dataset_zoo/ctw1500/textdet.py
 create mode 100644 dataset_zoo/ctw1500/textrecog.py
 create mode 100644 dataset_zoo/ctw1500/textspotting.py
 create mode 100644 mmocr/datasets/preparers/parsers/ctw1500_parser.py
 create mode 100644 tests/test_datasets/test_preparers/test_parsers/test_ctw1500_parser.py

diff --git a/dataset_zoo/ctw1500/metafile.yml b/dataset_zoo/ctw1500/metafile.yml
new file mode 100644
index 00000000..58af8367
--- /dev/null
+++ b/dataset_zoo/ctw1500/metafile.yml
@@ -0,0 +1,32 @@
+Name: 'CTW1500'
+Paper:
+  Title: Curved scene text detection via transverse and longitudinal sequence connection
+  URL: https://www.sciencedirect.com/science/article/pii/S0031320319300664
+  Venue: PR
+  Year: '2019'
+  BibTeX: '@article{liu2019curved,
+  title={Curved scene text detection via transverse and longitudinal sequence connection},
+  author={Liu, Yuliang and Jin, Lianwen and Zhang, Shuaitao and Luo, Canjie and Zhang, Sheng},
+  journal={Pattern Recognition},
+  volume={90},
+  pages={337--345},
+  year={2019},
+  publisher={Elsevier}
+}'
+Data:
+  Website: https://github.com/Yuliang-Liu/Curve-Text-Detector
+  Language:
+    - English
+  Scene:
+    - Scene
+  Granularity:
+    - Word
+    - Line
+  Tasks:
+    - textrecog
+    - textdet
+    - textspotting
+  License:
+    Type: N/A
+    Link: N/A
+  Format: .xml
diff --git a/dataset_zoo/ctw1500/textdet.py b/dataset_zoo/ctw1500/textdet.py
new file mode 100644
index 00000000..82783421
--- /dev/null
+++ b/dataset_zoo/ctw1500/textdet.py
@@ -0,0 +1,76 @@
+data_root = 'data/ctw1500'
+cache_path = 'data/cache'
+
+train_preparer = dict(
+    obtainer=dict(
+        type='NaiveDataObtainer',
+        cache_path=cache_path,
+        files=[
+            dict(
+                url='https://universityofadelaide.box.com/shared/static/'
+                'py5uwlfyyytbb2pxzq9czvu6fuqbjdh8.zip',
+                save_name='ctw1500_train_images.zip',
+                md5='f1453464b764343040644464d5c0c4fa',
+                split=['train'],
+                content=['image'],
+                mapping=[[
+                    'ctw1500_train_images/train_images', 'textdet_imgs/train'
+                ]]),
+            dict(
+                url='https://universityofadelaide.box.com/shared/static/'
+                'jikuazluzyj4lq6umzei7m2ppmt3afyw.zip',
+                save_name='ctw1500_train_labels.zip',
+                md5='d9ba721b25be95c2d78aeb54f812a5b1',
+                split=['train'],
+                content=['annotation'],
+                mapping=[[
+                    'ctw1500_train_labels/ctw1500_train_labels/',
+                    'annotations/train'
+                ]])
+        ]),
+    gatherer=dict(
+        type='PairGatherer',
+        img_suffixes=['.jpg', '.JPG'],
+        rule=[r'(\d{4}).jpg', r'\1.xml']),
+    parser=dict(type='CTW1500AnnParser'),
+    packer=dict(type='TextDetPacker'),
+    dumper=dict(type='JsonDumper'),
+)
+
+test_preparer = dict(
+    obtainer=dict(
+        type='NaiveDataObtainer',
+        cache_path=cache_path,
+        files=[
+            dict(
+                url='https://universityofadelaide.box.com/shared/static/'
+                't4w48ofnqkdw7jyc4t11nsukoeqk9c3d.zip',
+                save_name='ctw1500_test_images.zip',
+                md5='79103fd77dfdd2c70ae6feb3a2fb4530',
+                split=['test'],
+                content=['image'],
+                mapping=[[
+                    'ctw1500_test_images/test_images', 'textdet_imgs/test'
+                ]]),
+            dict(
+                url='https://cloudstor.aarnet.edu.au/plus/s/uoeFl0pCN9BOCN5/'
+                'download',
+                save_name='ctw1500_test_labels.zip',
+                md5='7f650933a30cf1bcdbb7874e4962a52b',
+                split=['test'],
+                content=['annotation'],
+                mapping=[['ctw1500_test_labels', 'annotations/test']])
+        ]),
+    gatherer=dict(
+        type='PairGatherer',
+        img_suffixes=['.jpg', '.JPG'],
+        rule=[r'(\d{4}).jpg', r'000\1.txt']),
+    parser=dict(type='CTW1500AnnParser'),
+    packer=dict(type='TextDetPacker'),
+    dumper=dict(type='JsonDumper'),
+)
+delete = [
+    'ctw1500_train_images', 'ctw1500_test_images', 'annotations',
+    'ctw1500_train_labels', 'ctw1500_test_labels'
+]
+config_generator = dict(type='TextDetConfigGenerator')
diff --git a/dataset_zoo/ctw1500/textrecog.py b/dataset_zoo/ctw1500/textrecog.py
new file mode 100644
index 00000000..c4436e07
--- /dev/null
+++ b/dataset_zoo/ctw1500/textrecog.py
@@ -0,0 +1,9 @@
+_base_ = ['textdet.py']
+
+_base_.train_preparer.gatherer.img_dir = 'textdet_imgs/train'
+_base_.test_preparer.gatherer.img_dir = 'textdet_imgs/test'
+
+_base_.train_preparer.packer.type = 'TextRecogCropPacker'
+_base_.test_preparer.packer.type = 'TextRecogCropPacker'
+
+config_generator = dict(type='TextRecogConfigGenerator')
diff --git a/dataset_zoo/ctw1500/textspotting.py b/dataset_zoo/ctw1500/textspotting.py
new file mode 100644
index 00000000..55d196d8
--- /dev/null
+++ b/dataset_zoo/ctw1500/textspotting.py
@@ -0,0 +1,19 @@
+_base_ = ['textdet.py']
+
+_base_.train_preparer.gatherer.img_dir = 'textdet_imgs/train'
+_base_.test_preparer.gatherer.img_dir = 'textdet_imgs/test'
+
+_base_.train_preparer.packer.type = 'TextSpottingPacker'
+_base_.test_preparer.packer.type = 'TextSpottingPacker'
+
+_base_.test_preparer.obtainer.files.append(
+    dict(
+        url='https://download.openmmlab.com/mmocr/data/1.x/textspotting/'
+        'ctw1500/lexicons.zip',
+        save_name='ctw1500_lexicons.zip',
+        md5='168150ca45da161917bf35a20e45b8d6',
+        content=['lexicons'],
+        mapping=[['ctw1500_lexicons/lexicons', 'lexicons']]))
+
+_base_.delete.append('ctw1500_lexicons')
+config_generator = dict(type='TextSpottingConfigGenerator')
diff --git a/mmocr/datasets/preparers/parsers/__init__.py b/mmocr/datasets/preparers/parsers/__init__.py
index 9b818863..8c79ca99 100644
--- a/mmocr/datasets/preparers/parsers/__init__.py
+++ b/mmocr/datasets/preparers/parsers/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .base import BaseParser
 from .coco_parser import COCOTextDetAnnParser
+from .ctw1500_parser import CTW1500AnnParser
 from .funsd_parser import FUNSDTextDetAnnParser
 from .icdar_txt_parser import (ICDARTxtTextDetAnnParser,
                                ICDARTxtTextRecogAnnParser)
@@ -14,5 +15,5 @@ __all__ = [
     'BaseParser', 'ICDARTxtTextDetAnnParser', 'ICDARTxtTextRecogAnnParser',
     'TotaltextTextDetAnnParser', 'WildreceiptKIEAnnParser',
     'COCOTextDetAnnParser', 'SVTTextDetAnnParser', 'FUNSDTextDetAnnParser',
-    'SROIETextDetAnnParser', 'NAFAnnParser'
+    'SROIETextDetAnnParser', 'NAFAnnParser', 'CTW1500AnnParser'
 ]
diff --git a/mmocr/datasets/preparers/parsers/ctw1500_parser.py b/mmocr/datasets/preparers/parsers/ctw1500_parser.py
new file mode 100644
index 00000000..4c6bdbc5
--- /dev/null
+++ b/mmocr/datasets/preparers/parsers/ctw1500_parser.py
@@ -0,0 +1,102 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import xml.etree.ElementTree as ET
+from typing import List, Tuple
+
+import numpy as np
+
+from mmocr.datasets.preparers.data_preparer import DATA_PARSERS
+from mmocr.datasets.preparers.parsers.base import BaseParser
+from mmocr.utils import list_from_file
+
+
+@DATA_PARSERS.register_module()
+class CTW1500AnnParser(BaseParser):
+    """SCUT-CTW1500 dataset parser.
+
+    Args:
+        ignore (str): The text of the ignored instances. Defaults to
+            '###'.
+    """
+
+    def __init__(self, ignore: str = '###', **kwargs) -> None:
+        self.ignore = ignore
+        super().__init__(**kwargs)
+
+    def parse_file(self, img_path: str, ann_path: str) -> Tuple:
+        """Convert annotation for a single image.
+
+        Args:
+            img_path (str): The path of image.
+            ann_path (str): The path of annotation.
+
+        Returns:
+            Tuple: A tuple of (img_path, instance).
+
+            - img_path (str): The path of image file, which can be read
+              directly by opencv.
+            - instance: instance is a list of dict containing parsed
+              annotations, which should contain the following keys:
+
+              - 'poly' or 'box' (textdet or textspotting)
+              - 'text' (textspotting or textrecog)
+              - 'ignore' (all task)
+
+        Examples:
+            An example of returned values:
+            >>> ('imgs/train/xxx.jpg',
+            >>> dict(
+            >>>    poly=[[[0, 1], [1, 1], [1, 0], [0, 0]]],
+            >>>    text='hello',
+            >>>    ignore=False)
+            >>> )
+        """
+
+        if self.split == 'train':
+            instances = self.load_xml_info(ann_path)
+        elif self.split == 'test':
+            instances = self.load_txt_info(ann_path)
+        return img_path, instances
+
+    def load_txt_info(self, anno_dir: str) -> List:
+        """Load the annotation of the SCUT-CTW dataset (test split).
+        Args:
+            anno_dir (str): Path to the annotation file.
+
+        Returns:
+            list[Dict]: List of instances.
+        """
+        instances = list()
+        for line in list_from_file(anno_dir):
+            # each line has one ploygen (n vetices), and one text.
+            # e.g., 695,885,866,888,867,1146,696,1143,####Latin 9
+            line = line.strip()
+            strs = line.split(',')
+            assert strs[28][0] == '#'
+            xy = [int(x) for x in strs[0:28]]
+            assert len(xy) == 28
+            poly = np.array(xy).reshape(-1).tolist()
+            text = strs[28][4:]
+            instances.append(
+                dict(poly=poly, text=text, ignore=text == self.ignore))
+        return instances
+
+    def load_xml_info(self, anno_dir: str) -> List:
+        """Load the annotation of the SCUT-CTW dataset (train split).
+        Args:
+            anno_dir (str): Path to the annotation file.
+
+        Returns:
+            list[Dict]: List of instances.
+        """
+        obj = ET.parse(anno_dir)
+        instances = list()
+        for image in obj.getroot():  # image
+            for box in image:  # image
+                text = box[0].text
+                segs = box[1].text
+                pts = segs.strip().split(',')
+                pts = [int(x) for x in pts]
+                assert len(pts) == 28
+                poly = np.array(pts).reshape(-1).tolist()
+                instances.append(dict(poly=poly, text=text, ignore=0))
+        return instances
diff --git a/tests/test_datasets/test_preparers/test_parsers/test_ctw1500_parser.py b/tests/test_datasets/test_preparers/test_parsers/test_ctw1500_parser.py
new file mode 100644
index 00000000..c6d5d52e
--- /dev/null
+++ b/tests/test_datasets/test_preparers/test_parsers/test_ctw1500_parser.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+import unittest
+
+from mmocr.datasets.preparers.parsers import CTW1500AnnParser
+from mmocr.utils import list_to_file
+
+
+class TestCTW1500AnnParser(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.root = tempfile.TemporaryDirectory()
+
+    def _create_dummy_ctw1500_det(self):
+        fake_train_anno = [
+            '<Annotations>',
+            ' <image file="0200.jpg">',
+            '   <box height="197" left="131" top="49" width="399">',
+            '     <label>OLATHE</label>',
+            '     <segs>131,58,208,49,279,56,346,76,412,101,473,141,530,192,510,246,458,210,405,175,350,151,291,137,228,133,165,134</segs>',  # noqa: E501
+            '     <pts x="183" y="95" />',
+            '     <pts x="251" y="89" />',
+            '     <pts x="322" y="107" />',
+            '     <pts x="383" y="124" />',
+            '     <pts x="441" y="161" />',
+            '     <pts x="493" y="201" />',
+            '   </box>',
+            ' </image>',
+            '</Annotations>',
+        ]
+        train_ann_file = osp.join(self.root.name, 'ctw1500_train.xml')
+        list_to_file(train_ann_file, fake_train_anno)
+
+        fake_test_anno = [
+            '48,84,61,79,75,73,88,68,102,74,116,79,130,84,135,73,119,67,104,60,89,56,74,61,59,67,45,73,#######',  # noqa: E501
+            '51,137,58,137,66,137,74,137,82,137,90,137,98,137,98,119,90,119,82,119,74,119,66,119,58,119,50,119,####E-313',  # noqa: E501
+            '41,155,49,155,57,155,65,155,73,155,81,155,89,155,87,136,79,136,71,136,64,136,56,136,48,136,41,137,#######',  # noqa: E501
+            '41,193,57,193,74,194,90,194,107,195,123,195,140,196,146,168,128,167,110,167,92,167,74,166,56,166,39,166,####F.D.N.Y.',  # noqa: E501
+        ]
+        test_ann_file = osp.join(self.root.name, 'ctw1500_test.txt')
+        list_to_file(test_ann_file, fake_test_anno)
+        return (osp.join(self.root.name,
+                         'ctw1500.jpg'), train_ann_file, test_ann_file)
+
+    def test_textdet_parsers(self):
+        parser = CTW1500AnnParser(split='train')
+        img_path, train_file, test_file = self._create_dummy_ctw1500_det()
+        img_path, instances = parser.parse_file(img_path, train_file)
+        self.assertEqual(img_path, osp.join(self.root.name, 'ctw1500.jpg'))
+        self.assertEqual(len(instances), 1)
+        self.assertEqual(instances[0]['text'], 'OLATHE')
+        self.assertEqual(instances[0]['poly'], [
+            131, 58, 208, 49, 279, 56, 346, 76, 412, 101, 473, 141, 530, 192,
+            510, 246, 458, 210, 405, 175, 350, 151, 291, 137, 228, 133, 165,
+            134
+        ])
+        self.assertEqual(instances[0]['ignore'], False)
+
+        parser = CTW1500AnnParser(split='test')
+        img_path, instances = parser.parse_file(img_path, test_file)
+        self.assertEqual(img_path, osp.join(self.root.name, 'ctw1500.jpg'))
+        self.assertEqual(len(instances), 4)
+        self.assertEqual(instances[0]['ignore'], True)
+        self.assertEqual(instances[1]['text'], 'E-313')
+        self.assertEqual(instances[3]['poly'], [
+            41, 193, 57, 193, 74, 194, 90, 194, 107, 195, 123, 195, 140, 196,
+            146, 168, 128, 167, 110, 167, 92, 167, 74, 166, 56, 166, 39, 166
+        ])
+
+    def tearDown(self) -> None:
+        self.root.cleanup()