add model for chinese (#156)

* add model for chinese

* update readme

* update readme

* fix link error; add vis for Chinese
Hongbin Sun 2021-05-09 21:49:08 +08:00 committed by GitHub
parent 5054a3f78d
commit 892d486d01
6 changed files with 217 additions and 9 deletions


@@ -20,4 +20,4 @@

| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
| :----: | :--------------: | :----------: | :------: | :-----: | :-------: | :----: | :-------: | :---: | :------: |
| [DBNet_r18](/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py) | ImageNet | ICDAR2015 Train | ICDAR2015 Test | 1200 | 736 | 0.731 | 0.871 | 0.795 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.log.json) |
| [DBNet_r50dcn](/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py) | [Synthtext](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_2e_synthtext_20210325-aa96e477.pth) | ICDAR2015 Train | ICDAR2015 Test | 1200 | 1024 | 0.796 | 0.866 | 0.830 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.log.json) |


@@ -50,6 +50,14 @@

| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 95.0 | 89.6 | 93.7 | | 79.0 | 82.2 | 88.9 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210327_154129.log.json) |
| [SAR](configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 95.2 | 88.7 | 92.4 | | 78.2 | 81.9 | 89.6 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_sequential_decoder_academic-d06c9a8e.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210330_105728.log.json) |

## Chinese Dataset

## Results and Models

| Methods | Backbone | Decoder | Download |
| :-----: | :------: | :-----: | :------: |
| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_chineseocr.py) | R31-1/8-1/4 | ParallelSARDecoder | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_chineseocr_20210507-b4be8214.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210506_225557.log.json) \| [dict](https://download.openmmlab.com/mmocr/textrecog/sar/dict_printed_chinese_english_digits.txt) |
**Notes:**

- `R31-1/8-1/4` means the height of the feature map from the backbone is 1/8 of the input image height, and its width is 1/4 of the input image width.
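To try the released Chinese checkpoint directly, the sketch below may help. It assumes the demo-style Python API MMOCR exposes at this point (`init_detector` from `mmdet.apis` plus `model_inference` from `mmocr.apis.inference`); the config and checkpoint paths are taken from the table above, and the test image name is a placeholder.

```python
from mmdet.apis import init_detector

from mmocr.apis.inference import model_inference
from mmocr.datasets import build_dataset  # noqa: F401, registers MMOCR datasets
from mmocr.models import build_detector  # noqa: F401, registers MMOCR models

config_file = 'configs/textrecog/sar/sar_r31_parallel_decoder_chineseocr.py'
checkpoint = ('https://download.openmmlab.com/mmocr/textrecog/sar/'
              'sar_r31_parallel_decoder_chineseocr_20210507-b4be8214.pth')

model = init_detector(config_file, checkpoint, device='cuda:0')
# The test set is wrapped in a ConcatDataset, so unwrap its pipeline for inference.
if model.cfg.data.test['type'] == 'ConcatDataset':
    model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][0].pipeline

result = model_inference(model, 'demo_text_recog.jpg')  # placeholder image path
print(result)  # e.g. {'text': '...', 'score': ...}
```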


@@ -0,0 +1,126 @@
_base_ = ['../../_base_/default_runtime.py']

dict_file = 'data/chineseocr/labels/dict_printed_chinese_english_digits.txt'
label_convertor = dict(
    type='AttnConvertor', dict_file=dict_file, with_unknown=True)

model = dict(
    type='SARNet',
    backbone=dict(type='ResNet31OCR'),
    encoder=dict(
        type='SAREncoder',
        enc_bi_rnn=False,
        enc_do_rnn=0.1,
        enc_gru=False,
    ),
    decoder=dict(
        type='ParallelSARDecoder',
        enc_bi_rnn=False,
        dec_bi_rnn=False,
        dec_do_rnn=0,
        dec_gru=False,
        pred_dropout=0.1,
        d_k=512,
        pred_concat=True),
    loss=dict(type='SARLoss'),
    label_convertor=label_convertor,
    max_seq_len=30)

# optimizer
optimizer = dict(type='Adam', lr=1e-3)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5

img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeOCR',
        height=48,
        min_width=48,
        max_width=256,
        keep_aspect_ratio=True,
        width_downsample_ratio=0.25),
    dict(type='ToTensorOCR'),
    dict(type='NormalizeOCR', **img_norm_cfg),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=[
            'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
        ]),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiRotateAugOCR',
        rotate_degrees=[0, 90, 270],
        transforms=[
            dict(
                type='ResizeOCR',
                height=48,
                min_width=48,
                max_width=256,
                keep_aspect_ratio=True,
                width_downsample_ratio=0.25),
            dict(type='ToTensorOCR'),
            dict(type='NormalizeOCR', **img_norm_cfg),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=[
                    'filename', 'ori_shape', 'img_shape', 'valid_ratio'
                ]),
        ])
]

dataset_type = 'OCRDataset'

train_prefix = 'data/chinese/'
train_ann_file = train_prefix + 'labels/train.txt'

train = dict(
    type=dataset_type,
    img_prefix=train_prefix,
    ann_file=train_ann_file,
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=train_pipeline,
    test_mode=False)

test_prefix = 'data/chineseocr/'
test_ann_file = test_prefix + 'labels/test.txt'

test = dict(
    type=dataset_type,
    img_prefix=test_prefix,
    ann_file=test_ann_file,
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=test_pipeline,
    test_mode=False)

data = dict(
    samples_per_gpu=40,
    workers_per_gpu=2,
    train=dict(type='ConcatDataset', datasets=[train]),
    val=dict(type='ConcatDataset', datasets=[test]),
    test=dict(type='ConcatDataset', datasets=[test]))

evaluation = dict(interval=1, metric='acc')
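Since the dataset sections above use `LineStrParser` with `keys=['filename', 'text']` and a blank-space separator, each line of `labels/train.txt` / `labels/test.txt` is expected to hold an image path followed by its transcription. Below is a small sanity-check sketch; the sample annotation line is made up, and the config path follows the README table above.

```python
from mmcv import Config

# One hypothetical annotation line in the expected format:
# "<image path> <transcription>", separated by a single space.
line = 'images/00001.jpg 中文识别示例'
parts = line.split(' ')
filename, text = parts[0], parts[1]  # matches keys_idx=[0, 1]
print(filename, text)

# Load the config the same way tools/train.py does and inspect a few fields.
cfg = Config.fromfile(
    'configs/textrecog/sar/sar_r31_parallel_decoder_chineseocr.py')
print(cfg.model.type, cfg.total_epochs, cfg.data.samples_per_gpu)
```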


@@ -1,4 +1,7 @@
import math
import os
import shutil
import urllib.request
import warnings

import cv2
@@ -6,6 +9,7 @@ import mmcv
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw, ImageFont

import mmocr.utils as utils
@@ -348,16 +352,22 @@ def imshow_text_label(img,
    resize_width = int(1.0 * src_w / src_h * resize_height)
    img = cv2.resize(img, (resize_width, resize_height))
    h, w = img.shape[:2]

    if is_contain_chinese(pred_label):
        pred_img = draw_texts_by_pil(img, [pred_label], None)
    else:
        pred_img = np.ones((h, w, 3), dtype=np.uint8) * 255
        cv2.putText(pred_img, pred_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX,
                    0.9, (0, 0, 255), 2)
    images = [pred_img, img]

    if gt_label != '':
        if is_contain_chinese(gt_label):
            gt_img = draw_texts_by_pil(img, [gt_label], None)
        else:
            gt_img = np.ones((h, w, 3), dtype=np.uint8) * 255
            cv2.putText(gt_img, gt_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX,
                        0.9, (255, 0, 0), 2)
        images.append(gt_img)

    img = tile_image(images)
@@ -504,6 +514,62 @@ def draw_texts(img, boxes, texts):
    return out_img


def draw_texts_by_pil(img, texts, boxes=None):
    """Draw boxes and texts on empty image, especially for Chinese.

    Args:
        img (np.ndarray): The original image.
        texts (list[str]): Recognized texts.
        boxes (list[list[float]]): Detected bounding boxes.
    Return:
        out_img (np.ndarray): Visualized text image.
    """
    color_list = gen_color()
    h, w = img.shape[:2]
    if boxes is None:
        boxes = [[0, 0, w, 0, w, h, 0, h]]

    out_img = Image.new('RGB', (w, h), color=(255, 255, 255))
    out_draw = ImageDraw.Draw(out_img)
    for idx, (box, text) in enumerate(zip(boxes, texts)):
        min_x, max_x = min(box[0::2]), max(box[0::2])
        min_y, max_y = min(box[1::2]), max(box[1::2])
        color = tuple(list(color_list[idx % len(color_list)])[::-1])
        out_draw.line(box, fill=color, width=1)
        box_width = max(max_x - min_x, max_y - min_y)
        font_size = int(0.9 * box_width / len(text))
        dirname, _ = os.path.split(os.path.abspath(__file__))
        font_path = os.path.join(dirname, 'font.TTF')
        if not os.path.exists(font_path):
            url = ('http://download.openmmlab.com/mmocr/data/font.TTF')
            print(f'Downloading {url} ...')
            local_filename, _ = urllib.request.urlretrieve(url)
            shutil.move(local_filename, font_path)
        fnt = ImageFont.truetype(font_path, font_size)
        out_draw.text((min_x + 1, min_y + 1), text, font=fnt, fill=(0, 0, 0))
    del out_draw

    out_img = cv2.cvtColor(np.asarray(out_img), cv2.COLOR_RGB2BGR)

    return out_img


def is_contain_chinese(check_str):
    """Check whether string contains Chinese or not.

    Args:
        check_str (str): String to be checked.

    Return True if contains Chinese, else False.
    """
    for ch in check_str:
        if u'\u4e00' <= ch <= u'\u9fff':
            return True
    return False


def det_recog_show_result(img, end2end_res):
    """Draw `result`(boxes and texts) on `img`.

    Args:
@@ -519,6 +585,10 @@ def det_recog_show_result(img, end2end_res):
        boxes.append(res['box'])
        texts.append(res['text'])
    box_vis_img = draw_polygons(img, boxes)
    if is_contain_chinese(''.join(texts)):
        text_vis_img = draw_texts_by_pil(img, texts, boxes)
    else:
        text_vis_img = draw_texts(img, boxes, texts)

    h, w = img.shape[:2]
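For a quick way to exercise the two new helpers on their own, here is a minimal sketch. It assumes they live in `mmocr/core/visualize.py` and are importable as shown; the canvas size and label are made up.

```python
import numpy as np

from mmocr.core.visualize import draw_texts_by_pil, is_contain_chinese

img = np.full((64, 256, 3), 255, dtype=np.uint8)  # dummy white canvas
label = '中文识别'
assert is_contain_chinese(label)

# With boxes=None the text is drawn over a blank canvas of the same size;
# the font file is downloaded on first use if it is not found locally.
vis = draw_texts_by_pil(img, [label], None)
print(vis.shape)  # (64, 256, 3)
```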


@@ -15,6 +15,10 @@ def test_det_recog_show_result():
    }
    vis_img = det_recog_show_result(img, det_recog_res)

    assert vis_img.shape[0] == 100
    assert vis_img.shape[1] == 200
    assert vis_img.shape[2] == 3

    det_recog_res['result'][0]['text'] = '中文'
    det_recog_show_result(img, det_recog_res)


@@ -57,7 +57,7 @@ def test_show_text_label(mock_imwrite, mock_imshow, mock_imread):
    visualize_utils.imshow_text_label(
        img, pred_label, gt_label, out_file=out_file)
    visualize_utils.imshow_text_label(
        img, '中文', '中文', out_file=None, show=True)

    # test showing img
    mock_imshow.assert_called_once()