mirror of https://github.com/open-mmlab/mmocr.git
add model for chinese (#156)
* add model for chinese * update readme * update readme * fix link error; add vis for Chinesepull/165/head
parent
5054a3f78d
commit
892d486d01
configs
textdet/dbnet
textrecog/sar
mmocr/core
tests
test_core
test_utils/test_text
|
@ -20,4 +20,4 @@
|
|||
| Method | Pretrained Model | Training set | Test set | #epochs | Test size | Recall | Precision | Hmean | Download |
|
||||
| :--------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------: | :-------------: | :------------: | :-----: | :-------: | :----: | :-------: | :---: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| [DBNet_r18](/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py) | ImageNet | ICDAR2015 Train | ICDAR2015 Test | 1200 | 736 | 0.731 | 0.871 | 0.795 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.log.json) |
|
||||
| [DBNet_r50dcn](/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py) | [Synthtext](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_2e_synthtext_20210325-aa96e477.pth) | ICDAR2015 Train | ICDAR2015 Test | 1200 | 1024 | 0.796 | 0.866 | 0.830 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth.log.json) |
|
||||
| [DBNet_r50dcn](/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py) | [Synthtext](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_2e_synthtext_20210325-aa96e477.pth) | ICDAR2015 Train | ICDAR2015 Test | 1200 | 1024 | 0.796 | 0.866 | 0.830 | [model](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth) \| [log](https://download.openmmlab.com/mmocr/textdet/dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.log.json) |
|
||||
|
|
|
@ -50,6 +50,14 @@
|
|||
| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 95.0 | 89.6 | 93.7 | | 79.0 | 82.2 | 88.9 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210327_154129.log.json) |
|
||||
| [SAR](configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 95.2 | 88.7 | 92.4 | | 78.2 | 81.9 | 89.6 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_sequential_decoder_academic-d06c9a8e.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210330_105728.log.json) |
|
||||
|
||||
## Chinese Dataset
|
||||
|
||||
## Results and Models
|
||||
|
||||
|Methods| Backbone | Decoder || download |
|
||||
| :-----: | :------: | :-------: | :-------: | :---: |
|
||||
| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_chineseocr.py) | R31-1/8-1/4 | ParallelSARDecoder || [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_chineseocr_20210507-b4be8214.pth) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210506_225557.log.json) \| [dict](https://download.openmmlab.com/mmocr/textrecog/sar/dict_printed_chinese_english_digits.txt) |
|
||||
|
||||
**Notes:**
|
||||
|
||||
- `R31-1/8-1/4` means the height of feature from backbone is 1/8 of input image, where 1/4 for width.
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
_base_ = ['../../_base_/default_runtime.py']
|
||||
|
||||
dict_file = 'data/chineseocr/labels/dict_printed_chinese_english_digits.txt'
|
||||
label_convertor = dict(
|
||||
type='AttnConvertor', dict_file=dict_file, with_unknown=True)
|
||||
|
||||
model = dict(
|
||||
type='SARNet',
|
||||
backbone=dict(type='ResNet31OCR'),
|
||||
encoder=dict(
|
||||
type='SAREncoder',
|
||||
enc_bi_rnn=False,
|
||||
enc_do_rnn=0.1,
|
||||
enc_gru=False,
|
||||
),
|
||||
decoder=dict(
|
||||
type='ParallelSARDecoder',
|
||||
enc_bi_rnn=False,
|
||||
dec_bi_rnn=False,
|
||||
dec_do_rnn=0,
|
||||
dec_gru=False,
|
||||
pred_dropout=0.1,
|
||||
d_k=512,
|
||||
pred_concat=True),
|
||||
loss=dict(type='SARLoss'),
|
||||
label_convertor=label_convertor,
|
||||
max_seq_len=30)
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(type='Adam', lr=1e-3)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
# learning policy
|
||||
lr_config = dict(policy='step', step=[3, 4])
|
||||
total_epochs = 5
|
||||
|
||||
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeOCR',
|
||||
height=48,
|
||||
min_width=48,
|
||||
max_width=256,
|
||||
keep_aspect_ratio=True,
|
||||
width_downsample_ratio=0.25),
|
||||
dict(type='ToTensorOCR'),
|
||||
dict(type='NormalizeOCR', **img_norm_cfg),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img'],
|
||||
meta_keys=[
|
||||
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
|
||||
]),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiRotateAugOCR',
|
||||
rotate_degrees=[0, 90, 270],
|
||||
transforms=[
|
||||
dict(
|
||||
type='ResizeOCR',
|
||||
height=48,
|
||||
min_width=48,
|
||||
max_width=256,
|
||||
keep_aspect_ratio=True,
|
||||
width_downsample_ratio=0.25),
|
||||
dict(type='ToTensorOCR'),
|
||||
dict(type='NormalizeOCR', **img_norm_cfg),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img'],
|
||||
meta_keys=[
|
||||
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
|
||||
]),
|
||||
])
|
||||
]
|
||||
|
||||
dataset_type = 'OCRDataset'
|
||||
|
||||
train_prefix = 'data/chinese/'
|
||||
|
||||
train_ann_file = train_prefix + 'labels/train.txt'
|
||||
|
||||
train = dict(
|
||||
type=dataset_type,
|
||||
img_prefix=train_prefix,
|
||||
ann_file=train_ann_file,
|
||||
loader=dict(
|
||||
type='HardDiskLoader',
|
||||
repeat=1,
|
||||
parser=dict(
|
||||
type='LineStrParser',
|
||||
keys=['filename', 'text'],
|
||||
keys_idx=[0, 1],
|
||||
separator=' ')),
|
||||
pipeline=train_pipeline,
|
||||
test_mode=False)
|
||||
|
||||
test_prefix = 'data/chineseocr/'
|
||||
|
||||
test_ann_file = test_prefix + 'labels/test.txt'
|
||||
|
||||
test = dict(
|
||||
type=dataset_type,
|
||||
img_prefix=test_prefix,
|
||||
ann_file=test_ann_file,
|
||||
loader=dict(
|
||||
type='HardDiskLoader',
|
||||
repeat=1,
|
||||
parser=dict(
|
||||
type='LineStrParser',
|
||||
keys=['filename', 'text'],
|
||||
keys_idx=[0, 1],
|
||||
separator=' ')),
|
||||
pipeline=test_pipeline,
|
||||
test_mode=False)
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=40,
|
||||
workers_per_gpu=2,
|
||||
train=dict(type='ConcatDataset', datasets=[train]),
|
||||
val=dict(type='ConcatDataset', datasets=[test]),
|
||||
test=dict(type='ConcatDataset', datasets=[test]))
|
||||
|
||||
evaluation = dict(interval=1, metric='acc')
|
|
@ -1,4 +1,7 @@
|
|||
import math
|
||||
import os
|
||||
import shutil
|
||||
import urllib
|
||||
import warnings
|
||||
|
||||
import cv2
|
||||
|
@ -6,6 +9,7 @@ import mmcv
|
|||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
import mmocr.utils as utils
|
||||
|
||||
|
@ -348,16 +352,22 @@ def imshow_text_label(img,
|
|||
resize_width = int(1.0 * src_w / src_h * resize_height)
|
||||
img = cv2.resize(img, (resize_width, resize_height))
|
||||
h, w = img.shape[:2]
|
||||
pred_img = np.ones((h, w, 3), dtype=np.uint8) * 255
|
||||
gt_img = np.ones((h, w, 3), dtype=np.uint8) * 255
|
||||
|
||||
cv2.putText(pred_img, pred_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9,
|
||||
(0, 0, 255), 2)
|
||||
if is_contain_chinese(pred_label):
|
||||
pred_img = draw_texts_by_pil(img, [pred_label], None)
|
||||
else:
|
||||
pred_img = np.ones((h, w, 3), dtype=np.uint8) * 255
|
||||
cv2.putText(pred_img, pred_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.9, (0, 0, 255), 2)
|
||||
images = [pred_img, img]
|
||||
|
||||
if gt_label != '':
|
||||
cv2.putText(gt_img, gt_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.9,
|
||||
(255, 0, 0), 2)
|
||||
if is_contain_chinese(gt_label):
|
||||
gt_img = draw_texts_by_pil(img, [gt_label], None)
|
||||
else:
|
||||
gt_img = np.ones((h, w, 3), dtype=np.uint8) * 255
|
||||
cv2.putText(gt_img, gt_label, (5, 40), cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.9, (255, 0, 0), 2)
|
||||
images.append(gt_img)
|
||||
|
||||
img = tile_image(images)
|
||||
|
@ -504,6 +514,62 @@ def draw_texts(img, boxes, texts):
|
|||
return out_img
|
||||
|
||||
|
||||
def draw_texts_by_pil(img, texts, boxes=None):
|
||||
"""Draw boxes and texts on empty image, especially for Chinese.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): The original image.
|
||||
texts (list[str]): Recognized texts.
|
||||
boxes (list[list[float]]): Detected bounding boxes.
|
||||
Return:
|
||||
out_img (np.ndarray): Visualized text image.
|
||||
"""
|
||||
|
||||
color_list = gen_color()
|
||||
h, w = img.shape[:2]
|
||||
if boxes is None:
|
||||
boxes = [[0, 0, w, 0, w, h, 0, h]]
|
||||
|
||||
out_img = Image.new('RGB', (w, h), color=(255, 255, 255))
|
||||
out_draw = ImageDraw.Draw(out_img)
|
||||
for idx, (box, text) in enumerate(zip(boxes, texts)):
|
||||
min_x, max_x = min(box[0::2]), max(box[0::2])
|
||||
min_y, max_y = min(box[1::2]), max(box[1::2])
|
||||
color = tuple(list(color_list[idx % len(color_list)])[::-1])
|
||||
out_draw.line(box, fill=color, width=1)
|
||||
box_width = max(max_x - min_x, max_y - min_y)
|
||||
font_size = int(0.9 * box_width / len(text))
|
||||
dirname, _ = os.path.split(os.path.abspath(__file__))
|
||||
font_path = os.path.join(dirname, 'font.TTF')
|
||||
if not os.path.exists(font_path):
|
||||
url = ('http://download.openmmlab.com/mmocr/data/font.TTF')
|
||||
print(f'Downloading {url} ...')
|
||||
local_filename, _ = urllib.request.urlretrieve(url)
|
||||
shutil.move(local_filename, font_path)
|
||||
fnt = ImageFont.truetype(font_path, font_size)
|
||||
out_draw.text((min_x + 1, min_y + 1), text, font=fnt, fill=(0, 0, 0))
|
||||
|
||||
del out_draw
|
||||
|
||||
out_img = cv2.cvtColor(np.asarray(out_img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
return out_img
|
||||
|
||||
|
||||
def is_contain_chinese(check_str):
|
||||
"""Check whether string contains Chinese or not.
|
||||
|
||||
Args:
|
||||
check_str (str): String to be checked.
|
||||
|
||||
Return True if contains Chinese, else False.
|
||||
"""
|
||||
for ch in check_str:
|
||||
if u'\u4e00' <= ch <= u'\u9fff':
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def det_recog_show_result(img, end2end_res):
|
||||
"""Draw `result`(boxes and texts) on `img`.
|
||||
Args:
|
||||
|
@ -519,7 +585,11 @@ def det_recog_show_result(img, end2end_res):
|
|||
boxes.append(res['box'])
|
||||
texts.append(res['text'])
|
||||
box_vis_img = draw_polygons(img, boxes)
|
||||
text_vis_img = draw_texts(img, boxes, texts)
|
||||
|
||||
if is_contain_chinese(''.join(texts)):
|
||||
text_vis_img = draw_texts_by_pil(img, texts, boxes)
|
||||
else:
|
||||
text_vis_img = draw_texts(img, boxes, texts)
|
||||
|
||||
h, w = img.shape[:2]
|
||||
out_img = np.ones((h, w * 2, 3), dtype=np.uint8)
|
||||
|
|
|
@ -15,6 +15,10 @@ def test_det_recog_show_result():
|
|||
}
|
||||
|
||||
vis_img = det_recog_show_result(img, det_recog_res)
|
||||
|
||||
assert vis_img.shape[0] == 100
|
||||
assert vis_img.shape[1] == 200
|
||||
assert vis_img.shape[2] == 3
|
||||
|
||||
det_recog_res['result'][0]['text'] = '中文'
|
||||
det_recog_show_result(img, det_recog_res)
|
||||
|
|
|
@ -57,7 +57,7 @@ def test_show_text_label(mock_imwrite, mock_imshow, mock_imread):
|
|||
visualize_utils.imshow_text_label(
|
||||
img, pred_label, gt_label, out_file=out_file)
|
||||
visualize_utils.imshow_text_label(
|
||||
img, pred_label, gt_label, out_file=None, show=True)
|
||||
img, '中文', '中文', out_file=None, show=True)
|
||||
|
||||
# test showing img
|
||||
mock_imshow.assert_called_once()
|
||||
|
|
Loading…
Reference in New Issue