add totaltext for recog and det (#357)

* add totaltext for recog and det

* add setup

* fix doc

* fix based on comments
pull/361/head
quincylin1 2021-07-08 21:52:50 +08:00 committed by GitHub
parent 1a5a880abb
commit 243f47dc03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 311 additions and 110 deletions

View File

@ -120,7 +120,7 @@ The structure of the text detection dataset directory is organized as follows.
python tools/data/textdet/textocr_converter.py /path/to/textocr
```
- For `Totaltext`:
- Step1: Download `totaltext.zip` from [github dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset) and `groundtruth_text.zip` from [github Groundtruth](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Groundtruth/Text) (We recommend downloading the text groundtruth with .mat format since our totaltext_converter.py supports groundtruth with .mat format only).
- Step1: Download `totaltext.zip` from [github dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset) and `groundtruth_text.zip` from [github Groundtruth](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Groundtruth/Text) (Our totaltext_converter.py supports groundtruth with both .mat and .txt format).
```bash
mkdir totaltext && cd totaltext
mkdir imgs && mkdir annotations
@ -339,7 +339,7 @@ python tools/data/utils/txt2lmdb.py -i data/mixture/Syn90k/label.txt -o data/mix
- For `Totaltext`:
- Step1: Download `totaltext.zip` from [github dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset) and `groundtruth_text.zip` from [github Groundtruth](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Groundtruth/Text) (We recommend downloading the text groundtruth with .mat format since our totaltext_converter.py supports groundtruth with .mat format only).
- Step1: Download `totaltext.zip` from [github dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset) and `groundtruth_text.zip` from [github Groundtruth](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Groundtruth/Text) (Our totaltext_converter.py supports groundtruth with both .mat and .txt format).
```bash
mkdir totaltext && cd totaltext
mkdir imgs && mkdir annotations

View File

@ -20,7 +20,7 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmocr
known_third_party = PIL,Polygon,cv2,imgaug,lanms,lmdb,matplotlib,mmcv,mmdet,numpy,packaging,pyclipper,pycocotools,pytest,rapidfuzz,scipy,shapely,skimage,titlecase,torch,torchvision
known_third_party = PIL,Polygon,cv2,imgaug,lanms,lmdb,matplotlib,mmcv,mmdet,numpy,packaging,pyclipper,pycocotools,pytest,rapidfuzz,scipy,shapely,skimage,titlecase,torch,torchvision,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY

View File

@ -1,12 +1,14 @@
import argparse
import glob
import os
import os.path as osp
from functools import partial
import re
import cv2
import mmcv
import numpy as np
import scipy.io as scio
import yaml
from shapely.geometry import Polygon
from mmocr.utils import convert_annotations, drop_orientation, is_not_png
@ -19,7 +21,6 @@ def collect_files(img_dir, gt_dir, split):
img_dir(str): The image directory
gt_dir(str): The groundtruth directory
split(str): The split of dataset. Namely: training or test
Returns:
files(list): The list of tuples (img_file, groundtruth_file)
"""
@ -37,63 +38,44 @@ def collect_files(img_dir, gt_dir, split):
for suffix in suffixes:
imgs_list.extend(glob.glob(osp.join(img_dir, '*' + suffix)))
imgs_list = [
drop_orientation(f) if is_not_png(f) else f for f in imgs_list
]
imgs_list = sorted(
[drop_orientation(f) if is_not_png(f) else f for f in imgs_list])
ann_list = sorted(
[osp.join(gt_dir, gt_file) for gt_file in os.listdir(gt_dir)])
files = []
if split == 'training':
for img_file in imgs_list:
gt_file = osp.join(
gt_dir,
'poly_gt_' + osp.splitext(osp.basename(img_file))[0] + '.mat')
files.append((img_file, gt_file))
assert len(files), f'No images found in {img_dir}'
print(f'Loaded {len(files)} images from {img_dir}')
elif split == 'test':
for img_file in imgs_list:
gt_file = osp.join(
gt_dir,
'poly_gt_' + osp.splitext(osp.basename(img_file))[0] + '.mat')
files.append((img_file, gt_file))
assert len(files), f'No images found in {img_dir}'
print(f'Loaded {len(files)} images from {img_dir}')
files = list(zip(imgs_list, ann_list))
assert len(files), f'No images found in {img_dir}'
print(f'Loaded {len(files)} images from {img_dir}')
return files
def collect_annotations(files, split, nproc=1):
def collect_annotations(files, nproc=1):
"""Collect the annotation information.
Args:
files(list): The list of tuples (image_file, groundtruth_file)
split(str): The split of dataset. Namely: training or test
nproc(int): The number of process to collect annotations
Returns:
images(list): The list of image information dicts
"""
assert isinstance(files, list)
assert isinstance(split, str)
assert isinstance(nproc, int)
load_img_info_with_split = partial(load_img_info, split=split)
if nproc > 1:
images = mmcv.track_parallel_progress(
load_img_info_with_split, files, nproc=nproc)
load_img_info, files, nproc=nproc)
else:
images = mmcv.track_progress(load_img_info_with_split, files)
images = mmcv.track_progress(load_img_info, files)
return images
def get_contours(gt_path, split):
"""Get the contours and words for each ground_truth file.
def get_contours_mat(gt_path):
"""Get the contours and words for each ground_truth mat file.
Args:
gt_path(str): The relative path of the ground_truth mat file
split(str): The split of dataset: training or test
Returns:
contours(list[lists]): A list of lists of contours
for the text instances
@ -101,15 +83,11 @@ def get_contours(gt_path, split):
for the text instances
"""
assert isinstance(gt_path, str)
assert isinstance(split, str)
contours = []
words = []
data = scio.loadmat(gt_path)
if split == 'training':
data_polygt = data['polygt']
elif split == 'test':
data_polygt = data['polygt']
data_polygt = data['polygt']
for i, lines in enumerate(data_polygt):
X = np.array(lines[1])
@ -138,23 +116,150 @@ def get_contours(gt_path, split):
return contours, words
def load_mat_info(img_info, gt_file, split):
def load_mat_info(img_info, gt_file):
"""Load the information of one ground truth in .mat format.
Args:
img_info(dict): The dict of only the image information
gt_file(str): The relative path of the ground_truth mat
file for one image
split(str): The split of dataset: training or test
Returns:
img_info(dict): The dict of the img and annotation information
"""
assert isinstance(img_info, dict)
assert isinstance(gt_file, str)
assert isinstance(split, str)
contours, words = get_contours(gt_file, split)
contours, words = get_contours_mat(gt_file)
anno_info = []
for contour in contours:
if contour.shape[0] == 2:
continue
category_id = 1
coordinates = np.array(contour).reshape(-1, 2)
polygon = Polygon(coordinates)
iscrowd = 0
area = polygon.area
# convert to COCO style XYWH format
min_x, min_y, max_x, max_y = polygon.bounds
bbox = [min_x, min_y, max_x - min_x, max_y - min_y]
anno = dict(
iscrowd=iscrowd,
category_id=category_id,
bbox=bbox,
area=area,
segmentation=[contour])
anno_info.append(anno)
img_info.update(anno_info=anno_info)
return img_info
def process_line(line, contours, words):
"""Get the contours and words by processing each line in the gt file.
Args:
line(str): The line in gt file containing annotation info
contours(list[lists]): A list of lists of contours
for the text instances
words(list[list]): A list of lists of words (string)
for the text instances
Returns:
contours(list[lists]): A list of lists of contours
for the text instances
words(list[list]): A list of lists of words (string)
for the text instances
"""
line = '{' + line.replace('[[', '[').replace(']]', ']') + '}'
ann_dict = re.sub('([0-9]) +([0-9])', r'\1,\2', line)
ann_dict = re.sub('([0-9]) +([ 0-9])', r'\1,\2', ann_dict)
ann_dict = re.sub('([0-9]) -([0-9])', r'\1,-\2', ann_dict)
ann_dict = ann_dict.replace("[u',']", "[u'#']")
ann_dict = yaml.load(ann_dict)
X = np.array([ann_dict['x']])
Y = np.array([ann_dict['y']])
if len(ann_dict['transcriptions']) == 0:
word = '???'
else:
word = ann_dict['transcriptions'][0]
if len(ann_dict['transcriptions']) > 1:
for ann_word in ann_dict['transcriptions'][1:]:
word += ',' + ann_word
word = str(eval(word))
words.append(word)
point_num = len(X[0])
arr = np.concatenate([X, Y]).T
contour = []
for i in range(point_num):
contour.append(arr[i][0])
contour.append(arr[i][1])
contours.append(np.asarray(contour))
return contours, words
def get_contours_txt(gt_path):
"""Get the contours and words for each ground_truth txt file.
Args:
gt_path(str): The relative path of the ground_truth mat file
Returns:
contours(list[lists]): A list of lists of contours
for the text instances
words(list[list]): A list of lists of words (string)
for the text instances
"""
assert isinstance(gt_path, str)
contours = []
words = []
with open(gt_path, 'r') as f:
tmp_line = ''
for idx, line in enumerate(f):
line = line.strip()
if idx == 0:
tmp_line = line
continue
if not line.startswith('x:'):
tmp_line += ' ' + line
continue
else:
complete_line = tmp_line
tmp_line = line
contours, words = process_line(complete_line, contours, words)
if tmp_line != '':
contours, words = process_line(tmp_line, contours, words)
for word in words:
if word == '#':
word = '###'
continue
return contours, words
def load_txt_info(gt_file, img_info):
"""Load the information of one ground truth in .txt format.
Args:
img_info(dict): The dict of only the image information
gt_file(str): The relative path of the ground_truth mat
file for one image
Returns:
img_info(dict): The dict of the img and annotation information
"""
contours, words = get_contours_txt(gt_file)
anno_info = []
for contour in contours:
if contour.shape[0] == 2:
@ -188,7 +293,6 @@ def load_png_info(gt_file, img_info):
Args:
gt_file(str): The relative path of the ground_truth file for one image
img_info(dict): The dict of only the image information
Returns:
img_info(dict): The dict of the img and annotation information
"""
@ -227,18 +331,15 @@ def load_png_info(gt_file, img_info):
return img_info
def load_img_info(files, split):
def load_img_info(files):
"""Load the information of one image.
Args:
files(tuple): The tuple of (img_file, groundtruth_file)
split(str): The split of dataset: training or test
Returns:
img_info(dict): The dict of the img and annotation information
"""
assert isinstance(files, tuple)
assert isinstance(split, str)
img_file, gt_file = files
# read imgs with ignoring orientations
@ -257,10 +358,10 @@ def load_img_info(files, split):
# anno_info=anno_info,
segm_file=osp.join(split_name, osp.basename(gt_file)))
if split == 'training':
img_info = load_mat_info(img_info, gt_file, split)
elif split == 'test':
img_info = load_mat_info(img_info, gt_file, split)
if osp.splitext(gt_file)[1] == '.mat':
img_info = load_mat_info(img_info, gt_file)
elif osp.splitext(gt_file)[1] == '.txt':
img_info = load_txt_info(gt_file, img_info)
else:
raise NotImplementedError
@ -303,7 +404,7 @@ def main():
print_tmpl='It takes {}s to convert totaltext annotation'):
files = collect_files(
osp.join(img_dir, split), osp.join(gt_dir, split), split)
image_infos = collect_annotations(files, split, nproc=args.nproc)
image_infos = collect_annotations(files, nproc=args.nproc)
convert_annotations(image_infos, osp.join(out_dir, json_name))

View File

@ -2,11 +2,12 @@ import argparse
import glob
import os
import os.path as osp
from functools import partial
import re
import mmcv
import numpy as np
import scipy.io as scio
import yaml
from shapely.geometry import Polygon
from mmocr.datasets.pipelines.crop import crop_img
@ -21,7 +22,6 @@ def collect_files(img_dir, gt_dir, split):
img_dir(str): The image directory
gt_dir(str): The groundtruth directory
split(str): The split of dataset. Namely: training or test
Returns:
files(list): The list of tuples (img_file, groundtruth_file)
"""
@ -32,70 +32,52 @@ def collect_files(img_dir, gt_dir, split):
# note that we handle png and jpg only. Pls convert others such as gif to
# jpg or png offline
suffixes = ['.png', '.jpg', '.jpeg']
suffixes = ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG']
# suffixes = ['.png']
imgs_list = []
for suffix in suffixes:
imgs_list.extend(glob.glob(osp.join(img_dir, '*' + suffix)))
imgs_list = [
drop_orientation(f) if is_not_png(f) else f for f in imgs_list
]
imgs_list = sorted(
[drop_orientation(f) if is_not_png(f) else f for f in imgs_list])
ann_list = sorted(
[osp.join(gt_dir, gt_file) for gt_file in os.listdir(gt_dir)])
files = []
if split == 'training':
for img_file in imgs_list:
gt_file = osp.join(
gt_dir,
'poly_gt_' + osp.splitext(osp.basename(img_file))[0] + '.mat')
files.append((img_file, gt_file))
assert len(files), f'No images found in {img_dir}'
print(f'Loaded {len(files)} images from {img_dir}')
elif split == 'test':
for img_file in imgs_list:
gt_file = osp.join(
gt_dir,
'poly_gt_' + osp.splitext(osp.basename(img_file))[0] + '.mat')
files.append((img_file, gt_file))
assert len(files), f'No images found in {img_dir}'
print(f'Loaded {len(files)} images from {img_dir}')
files = [(img_file, gt_file)
for (img_file, gt_file) in zip(imgs_list, ann_list)]
assert len(files), f'No images found in {img_dir}'
print(f'Loaded {len(files)} images from {img_dir}')
return files
def collect_annotations(files, split, nproc=1):
def collect_annotations(files, nproc=1):
"""Collect the annotation information.
Args:
files(list): The list of tuples (image_file, groundtruth_file)
split(str): The split of dataset. Namely: training or test
nproc(int): The number of process to collect annotations
Returns:
images(list): The list of image information dicts
"""
assert isinstance(files, list)
assert isinstance(split, str)
assert isinstance(nproc, int)
load_img_info_with_split = partial(load_img_info, split=split)
if nproc > 1:
images = mmcv.track_parallel_progress(
load_img_info_with_split, files, nproc=nproc)
load_img_info, files, nproc=nproc)
else:
images = mmcv.track_progress(load_img_info_with_split, files)
images = mmcv.track_progress(load_img_info, files)
return images
def get_contours(gt_path, split):
"""Get the contours and words for each ground_truth file.
def get_contours_mat(gt_path):
"""Get the contours and words for each ground_truth mat file.
Args:
gt_path(str): The relative path of the ground_truth mat file
split(str): The split of dataset: training or test
Returns:
contours(list[lists]): A list of lists of contours
for the text instances
@ -103,17 +85,13 @@ def get_contours(gt_path, split):
for the text instances
"""
assert isinstance(gt_path, str)
assert isinstance(split, str)
contours = []
words = []
data = scio.loadmat(gt_path)
if split == 'training':
data_polygt = data['polygt']
elif split == 'test':
data_polygt = data['polygt']
data_polygt = data['polygt']
for lines in data_polygt:
for i, lines in enumerate(data_polygt):
X = np.array(lines[1])
Y = np.array(lines[3])
@ -140,23 +118,140 @@ def get_contours(gt_path, split):
return contours, words
def load_mat_info(img_info, gt_file, split):
def load_mat_info(img_info, gt_file):
"""Load the information of one ground truth in .mat format.
Args:
img_info(dict): The dict of only the image information
gt_file(str): The relative path of the ground_truth mat
file for one image
split(str): The split of dataset: training or test
Returns:
img_info(dict): The dict of the img and annotation information
"""
assert isinstance(img_info, dict)
assert isinstance(gt_file, str)
assert isinstance(split, str)
contours, words = get_contours(gt_file, split)
contours, words = get_contours_mat(gt_file)
anno_info = []
for contour, word in zip(contours, words):
if contour.shape[0] == 2:
continue
coordinates = np.array(contour).reshape(-1, 2)
polygon = Polygon(coordinates)
# convert to COCO style XYWH format
min_x, min_y, max_x, max_y = polygon.bounds
bbox = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
anno = dict(word=word, bbox=bbox)
anno_info.append(anno)
img_info.update(anno_info=anno_info)
return img_info
def process_line(line, contours, words):
"""Get the contours and words by processing each line in the gt file.
Args:
line(str): The line in gt file containing annotation info
contours(list[lists]): A list of lists of contours
for the text instances
words(list[list]): A list of lists of words (string)
for the text instances
Returns:
contours(list[lists]): A list of lists of contours
for the text instances
words(list[list]): A list of lists of words (string)
for the text instances
"""
line = '{' + line.replace('[[', '[').replace(']]', ']') + '}'
ann_dict = re.sub('([0-9]) +([0-9])', r'\1,\2', line)
ann_dict = re.sub('([0-9]) +([ 0-9])', r'\1,\2', ann_dict)
ann_dict = re.sub('([0-9]) -([0-9])', r'\1,-\2', ann_dict)
ann_dict = ann_dict.replace("[u',']", "[u'#']")
ann_dict = yaml.load(ann_dict)
X = np.array([ann_dict['x']])
Y = np.array([ann_dict['y']])
if len(ann_dict['transcriptions']) == 0:
word = '???'
else:
word = ann_dict['transcriptions'][0]
if len(ann_dict['transcriptions']) > 1:
for ann_word in ann_dict['transcriptions'][1:]:
word += ',' + ann_word
word = str(eval(word))
words.append(word)
point_num = len(X[0])
arr = np.concatenate([X, Y]).T
contour = []
for i in range(point_num):
contour.append(arr[i][0])
contour.append(arr[i][1])
contours.append(np.asarray(contour))
return contours, words
def get_contours_txt(gt_path):
"""Get the contours and words for each ground_truth txt file.
Args:
gt_path(str): The relative path of the ground_truth mat file
Returns:
contours(list[lists]): A list of lists of contours
for the text instances
words(list[list]): A list of lists of words (string)
for the text instances
"""
assert isinstance(gt_path, str)
contours = []
words = []
with open(gt_path, 'r') as f:
tmp_line = ''
for idx, line in enumerate(f):
line = line.strip()
if idx == 0:
tmp_line = line
continue
if not line.startswith('x:'):
tmp_line += ' ' + line
continue
else:
complete_line = tmp_line
tmp_line = line
contours, words = process_line(complete_line, contours, words)
if tmp_line != '':
contours, words = process_line(tmp_line, contours, words)
for word in words:
if word == '#':
word = '###'
continue
return contours, words
def load_txt_info(gt_file, img_info):
"""Load the information of one ground truth in .txt format.
Args:
img_info(dict): The dict of only the image information
gt_file(str): The relative path of the ground_truth mat
file for one image
Returns:
img_info(dict): The dict of the img and annotation information
"""
contours, words = get_contours_txt(gt_file)
anno_info = []
for contour, word in zip(contours, words):
if contour.shape[0] == 2:
@ -175,6 +270,14 @@ def load_mat_info(img_info, gt_file, split):
def generate_ann(root_path, split, image_infos):
"""Generate cropped annotations and label txt file.
Args:
root_path(str): The relative path of the totaltext file
split(str): The split of dataset. Namely: training or test
image_infos(list[dict]): A list of dicts of the img and
annotation information
"""
dst_image_root = osp.join(root_path, 'dst_imgs', split)
if split == 'training':
@ -202,18 +305,15 @@ def generate_ann(root_path, split, image_infos):
list_to_file(dst_label_file, lines)
def load_img_info(files, split):
def load_img_info(files):
"""Load the information of one image.
Args:
files(tuple): The tuple of (img_file, groundtruth_file)
split(str): The split of dataset: training or test
Returns:
img_info(dict): The dict of the img and annotation information
"""
assert isinstance(files, tuple)
assert isinstance(split, str)
img_file, gt_file = files
# read imgs with ignoring orientations
@ -232,10 +332,10 @@ def load_img_info(files, split):
# anno_info=anno_info,
segm_file=osp.join(split_name, osp.basename(gt_file)))
if split == 'training':
img_info = load_mat_info(img_info, gt_file, split)
elif split == 'test':
img_info = load_mat_info(img_info, gt_file, split)
if osp.splitext(gt_file)[1] == '.mat':
img_info = load_mat_info(img_info, gt_file)
elif osp.splitext(gt_file)[1] == '.txt':
img_info = load_txt_info(gt_file, img_info)
else:
raise NotImplementedError
@ -278,7 +378,7 @@ def main():
print_tmpl='It takes {}s to convert totaltext annotation'):
files = collect_files(
osp.join(img_dir, split), osp.join(gt_dir, split), split)
image_infos = collect_annotations(files, split, nproc=args.nproc)
image_infos = collect_annotations(files, nproc=args.nproc)
generate_ann(root_path, split, image_infos)