add model for chinese ()

* add model for chinese

* update readme

* update readme

* fix link error; add vis for Chinese
pull/165/head
Hongbin Sun 2021-05-09 21:49:08 +08:00 committed by GitHub
parent 5054a3f78d
commit 892d486d01
6 changed files with 217 additions and 9 deletions

configs/textdet/dbnet/README.md

@@ -20,4 +20,4 @@
| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
| :--------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------: | :-------------: | :------------: | :-----: | :-------: | :----: | :-------: | :---: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [DBNet_r18](/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py) | ImageNet | ICDAR2015 Train | ICDAR2015 Test | 1200 | 736 | 0.731 | 0.871 | 0.795 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.log.json) |
-| [DBNet_r50dcn](/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py) | [Synthtext](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_2e_synthtext_20210325-aa96e477.pth) | ICDAR2015 Train | ICDAR2015 Test | 1200 | 1024 | 0.796 | 0.866 | 0.830 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth.log.json) |
+| [DBNet_r50dcn](/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py) | [Synthtext](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_2e_synthtext_20210325-aa96e477.pth) | ICDAR2015 Train | ICDAR2015 Test | 1200 | 1024 | 0.796 | 0.866 | 0.830 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.log.json) |

configs/textrecog/sar/README.md

@@ -50,6 +50,14 @@
| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 95.0 | 89.6 | 93.7 | | 79.0 | 82.2 | 88.9 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210327_154129.log.json) |
| [SAR](configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 95.2 | 88.7 | 92.4 | | 78.2 | 81.9 | 89.6 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_sequential_decoder_academic-d06c9a8e.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210330_105728.log.json) |

## Chinese Dataset

## Results and Models

| Methods | Backbone | Decoder |  | download |
| :-----: | :------: | :-------: | :-------: | :---: |
| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_chineseocr.py) | R31-1/8-1/4 | ParallelSARDecoder |  | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_chineseocr_20210507-b4be8214.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210506_225557.log.json) \| [dict](https://download.openmmlab.com/mmocr/textrecog/sar/dict_printed_chinese_english_digits.txt) |

**Notes:**

- `R31-1/8-1/4` means the feature map output by the backbone has 1/8 the height and 1/4 the width of the input image (see the shape sketch below).
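
As a quick illustration of that downsampling ratio, here is a minimal arithmetic sketch. It does not call the actual `ResNet31OCR` backbone and ignores padding details; the input sizes simply follow the `ResizeOCR` settings in the config below (height 48, width up to 256).

```python
# Hypothetical helper: pure arithmetic for the R31-1/8-1/4 ratios quoted
# above; the real backbone output may differ slightly due to padding.
def backbone_feature_size(img_h, img_w, h_ratio=1 / 8, w_ratio=1 / 4):
    return int(img_h * h_ratio), int(img_w * w_ratio)


print(backbone_feature_size(48, 256))  # (6, 64)
print(backbone_feature_size(48, 160))  # (6, 40)
```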

configs/textrecog/sar/sar_r31_parallel_decoder_chineseocr.py

@@ -0,0 +1,126 @@
_base_ = ['../../_base_/default_runtime.py']

dict_file = 'data/chineseocr/labels/dict_printed_chinese_english_digits.txt'
label_convertor = dict(
    type='AttnConvertor', dict_file=dict_file, with_unknown=True)

model = dict(
    type='SARNet',
    backbone=dict(type='ResNet31OCR'),
    encoder=dict(
        type='SAREncoder',
        enc_bi_rnn=False,
        enc_do_rnn=0.1,
        enc_gru=False,
    ),
    decoder=dict(
        type='ParallelSARDecoder',
        enc_bi_rnn=False,
        dec_bi_rnn=False,
        dec_do_rnn=0,
        dec_gru=False,
        pred_dropout=0.1,
        d_k=512,
        pred_concat=True),
    loss=dict(type='SARLoss'),
    label_convertor=label_convertor,
    max_seq_len=30)

# optimizer
optimizer = dict(type='Adam', lr=1e-3)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5

img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeOCR',
        height=48,
        min_width=48,
        max_width=256,
        keep_aspect_ratio=True,
        width_downsample_ratio=0.25),
    dict(type='ToTensorOCR'),
    dict(type='NormalizeOCR', **img_norm_cfg),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=[
            'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
        ]),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiRotateAugOCR',
        rotate_degrees=[0, 90, 270],
        transforms=[
            dict(
                type='ResizeOCR',
                height=48,
                min_width=48,
                max_width=256,
                keep_aspect_ratio=True,
                width_downsample_ratio=0.25),
            dict(type='ToTensorOCR'),
            dict(type='NormalizeOCR', **img_norm_cfg),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=[
                    'filename', 'ori_shape', 'img_shape', 'valid_ratio'
                ]),
        ])
]

dataset_type = 'OCRDataset'

train_prefix = 'data/chinese/'
train_ann_file = train_prefix + 'labels/train.txt'

train = dict(
    type=dataset_type,
    img_prefix=train_prefix,
    ann_file=train_ann_file,
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=train_pipeline,
    test_mode=False)

test_prefix = 'data/chineseocr/'
test_ann_file = test_prefix + 'labels/test.txt'

test = dict(
    type=dataset_type,
    img_prefix=test_prefix,
    ann_file=test_ann_file,
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=test_pipeline,
    test_mode=False)

data = dict(
    samples_per_gpu=40,
    workers_per_gpu=2,
    train=dict(type='ConcatDataset', datasets=[train]),
    val=dict(type='ConcatDataset', datasets=[test]),
    test=dict(type='ConcatDataset', datasets=[test]))

evaluation = dict(interval=1, metric='acc')
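
The `LineStrParser` blocks above (keys `['filename', 'text']`, `keys_idx=[0, 1]`, `separator=' '`) imply a plain-text annotation file with one `<image path> <transcription>` pair per line. Below is a minimal sketch of that split, assuming a single-space separator; it is an illustration only, not the actual `LineStrParser` implementation, and the sample line is made up.

```python
# Illustration of the annotation format implied by the LineStrParser config:
# each line is '<filename><space><text>', mapped via keys_idx=[0, 1].
def parse_ann_line(line, separator=' '):
    parts = line.rstrip('\n').split(separator)
    return {'filename': parts[0], 'text': parts[1]}


# Hypothetical line from labels/train.txt
print(parse_ann_line('imgs/00001.jpg 中文识别'))
# {'filename': 'imgs/00001.jpg', 'text': '中文识别'}
```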

mmocr/core/visualize.py

@@ -1,4 +1,7 @@
import math
import os
import shutil
import urllib
import warnings
import cv2
@@ -6,6 +9,7 @@ import mmcv
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import mmocr.utils as utils
@@ -348,16 +352,22 @@ def imshow_text_label(img,
    resize_width = int(1.0 * src_w / src_h * resize_height)
    img = cv2.resize(img, (resize_width, resize_height))
    h, w = img.shape[:2]
-    pred_img = np.ones((h, w, 3), dtype=np.uint8) * 255
-    gt_img = np.ones((h, w, 3), dtype=np.uint8) * 255
-    cv2.putText(pred_img, pred_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9,
-                (0, 0, 255), 2)
+    if is_contain_chinese(pred_label):
+        pred_img = draw_texts_by_pil(img, [pred_label], None)
+    else:
+        pred_img = np.ones((h, w, 3), dtype=np.uint8) * 255
+        cv2.putText(pred_img, pred_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX,
+                    0.9, (0, 0, 255), 2)
    images = [pred_img, img]

    if gt_label != '':
-        cv2.putText(gt_img, gt_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9,
-                    (255, 0, 0), 2)
+        if is_contain_chinese(gt_label):
+            gt_img = draw_texts_by_pil(img, [gt_label], None)
+        else:
+            gt_img = np.ones((h, w, 3), dtype=np.uint8) * 255
+            cv2.putText(gt_img, gt_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX,
+                        0.9, (255, 0, 0), 2)
        images.append(gt_img)

    img = tile_image(images)
@@ -504,6 +514,62 @@ def draw_texts(img, boxes, texts):
    return out_img


def draw_texts_by_pil(img, texts, boxes=None):
    """Draw boxes and texts on empty image, especially for Chinese.

    Args:
        img (np.ndarray): The original image.
        texts (list[str]): Recognized texts.
        boxes (list[list[float]]): Detected bounding boxes.
    Return:
        out_img (np.ndarray): Visualized text image.
    """
    color_list = gen_color()
    h, w = img.shape[:2]
    if boxes is None:
        boxes = [[0, 0, w, 0, w, h, 0, h]]

    out_img = Image.new('RGB', (w, h), color=(255, 255, 255))
    out_draw = ImageDraw.Draw(out_img)
    for idx, (box, text) in enumerate(zip(boxes, texts)):
        min_x, max_x = min(box[0::2]), max(box[0::2])
        min_y, max_y = min(box[1::2]), max(box[1::2])
        color = tuple(list(color_list[idx % len(color_list)])[::-1])
        out_draw.line(box, fill=color, width=1)
        box_width = max(max_x - min_x, max_y - min_y)
        font_size = int(0.9 * box_width / len(text))
        dirname, _ = os.path.split(os.path.abspath(__file__))
        font_path = os.path.join(dirname, 'font.TTF')
        if not os.path.exists(font_path):
            url = ('http://download.openmmlab.com/mmocr/data/font.TTF')
            print(f'Downloading {url} ...')
            local_filename, _ = urllib.request.urlretrieve(url)
            shutil.move(local_filename, font_path)
        fnt = ImageFont.truetype(font_path, font_size)
        out_draw.text((min_x + 1, min_y + 1), text, font=fnt, fill=(0, 0, 0))
    del out_draw

    out_img = cv2.cvtColor(np.asarray(out_img), cv2.COLOR_RGB2BGR)

    return out_img


def is_contain_chinese(check_str):
    """Check whether string contains Chinese or not.

    Args:
        check_str (str): String to be checked.

    Return True if contains Chinese, else False.
    """
    for ch in check_str:
        if u'\u4e00' <= ch <= u'\u9fff':
            return True
    return False


def det_recog_show_result(img, end2end_res):
    """Draw `result`(boxes and texts) on `img`.

    Args:
@@ -519,7 +585,11 @@ def det_recog_show_result(img, end2end_res):
        boxes.append(res['box'])
        texts.append(res['text'])
    box_vis_img = draw_polygons(img, boxes)
-    text_vis_img = draw_texts(img, boxes, texts)
+    if is_contain_chinese(''.join(texts)):
+        text_vis_img = draw_texts_by_pil(img, texts, boxes)
+    else:
+        text_vis_img = draw_texts(img, boxes, texts)

    h, w = img.shape[:2]
    out_img = np.ones((h, w * 2, 3), dtype=np.uint8)
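
For reference, a minimal usage sketch of the two helpers added above, `is_contain_chinese` and `draw_texts_by_pil`. The import path assumes they live in `mmocr/core/visualize.py` as in this diff; the canvas and label are made up for illustration, and the first call downloads `font.TTF` next to the module as the code above shows.

```python
import numpy as np

from mmocr.core.visualize import draw_texts_by_pil, is_contain_chinese

# Made-up white canvas and a Chinese label, purely for illustration.
img = np.full((64, 256, 3), 255, dtype=np.uint8)
label = '中文识别'

assert is_contain_chinese(label)  # characters fall in the \u4e00-\u9fff range

# With boxes=None the whole image acts as one text box, which is how
# imshow_text_label uses it for single-label visualization.
vis = draw_texts_by_pil(img, [label], boxes=None)
print(vis.shape)  # (64, 256, 3), BGR uint8 like the cv2-based path
```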


@@ -15,6 +15,10 @@ def test_det_recog_show_result():
    }
    vis_img = det_recog_show_result(img, det_recog_res)

    assert vis_img.shape[0] == 100
    assert vis_img.shape[1] == 200
    assert vis_img.shape[2] == 3

    det_recog_res['result'][0]['text'] = '中文'
    det_recog_show_result(img, det_recog_res)


@@ -57,7 +57,7 @@ def test_show_text_label(mock_imwrite, mock_imshow, mock_imread):
    visualize_utils.imshow_text_label(
        img, pred_label, gt_label, out_file=out_file)
    visualize_utils.imshow_text_label(
-        img, pred_label, gt_label, out_file=None, show=True)
+        img, '中文', '中文', out_file=None, show=True)

    # test showing img
    mock_imshow.assert_called_once()