mirror of https://github.com/open-mmlab/mmocr.git
Merge pull request #5 from yuexy/feature/crnn_and_robustscanner
[feature]: add CRNN and RobustScannerpull/2/head
commit
50287450ab
|
@ -1,5 +1,5 @@
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<img src="resources/mmocr-logo.jpg" width="500px"/>
|
<img src="resources/mmocr-logo.png" width="500px"/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
## Introduction
|
## Introduction
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
# An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
[ALGORITHM]
|
||||||
|
|
||||||
|
```latex
|
||||||
|
@article{shi2016end,
|
||||||
|
title={An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition},
|
||||||
|
author={Shi, Baoguang and Bai, Xiang and Yao, Cong},
|
||||||
|
journal={IEEE transactions on pattern analysis and machine intelligence},
|
||||||
|
year={2016}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Results and Models
|
||||||
|
|
||||||
|
### Train Dataset
|
||||||
|
|
||||||
|
| trainset | instance_num | repeat_num | note |
|
||||||
|
| :------: | :----------: | :--------: | :---: |
|
||||||
|
| Syn90k | 8919273 | 1 | synth |
|
||||||
|
|
||||||
|
### Test Dataset
|
||||||
|
|
||||||
|
| testset | instance_num | note |
|
||||||
|
| :-----: | :----------: | :-----: |
|
||||||
|
| IIIT5K | 3000 | regular |
|
||||||
|
| SVT | 647 | regular |
|
||||||
|
| IC13 | 1015 | regular |
|
||||||
|
|
||||||
|
## Results and models
|
||||||
|
|
||||||
|
| methods | | Regular Text | | | | Irregular Text | | download |
|
||||||
|
| :-----: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :---------------------------------------------------------: |
|
||||||
|
| methods | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
|
||||||
|
| CRNN | 80.5 | 81.5 | 86.5 | | - | - | - | [config](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_academic_dataset.py) [log]() [model](https) |
|
|
@ -0,0 +1,155 @@
|
||||||
|
_base_ = []
|
||||||
|
checkpoint_config = dict(interval=1)
|
||||||
|
# yapf:disable
|
||||||
|
log_config = dict(
|
||||||
|
interval=1,
|
||||||
|
hooks=[
|
||||||
|
dict(type='TextLoggerHook')
|
||||||
|
|
||||||
|
])
|
||||||
|
# yapf:enable
|
||||||
|
dist_params = dict(backend='nccl')
|
||||||
|
log_level = 'INFO'
|
||||||
|
load_from = None
|
||||||
|
resume_from = None
|
||||||
|
workflow = [('train', 1)]
|
||||||
|
|
||||||
|
# model
|
||||||
|
label_convertor = dict(
|
||||||
|
type='CTCConvertor', dict_type='DICT36', with_unknown=False, lower=True)
|
||||||
|
|
||||||
|
model = dict(
|
||||||
|
type='CRNNNet',
|
||||||
|
preprocessor=None,
|
||||||
|
backbone=dict(type='VeryDeepVgg', leakyRelu=False, input_channels=1),
|
||||||
|
encoder=None,
|
||||||
|
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
|
||||||
|
loss=dict(type='CTCLoss'),
|
||||||
|
label_convertor=label_convertor,
|
||||||
|
pretrained=None)
|
||||||
|
|
||||||
|
train_cfg = None
|
||||||
|
test_cfg = None
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
optimizer = dict(type='Adadelta', lr=1.0)
|
||||||
|
optimizer_config = dict(grad_clip=None)
|
||||||
|
# learning policy
|
||||||
|
lr_config = dict(policy='step', step=[])
|
||||||
|
total_epochs = 5
|
||||||
|
|
||||||
|
# data
|
||||||
|
img_norm_cfg = dict(mean=[0.5], std=[0.5])
|
||||||
|
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='ResizeOCR',
|
||||||
|
height=32,
|
||||||
|
min_width=100,
|
||||||
|
max_width=100,
|
||||||
|
keep_aspect_ratio=False),
|
||||||
|
dict(type='ToTensorOCR'),
|
||||||
|
dict(type='NormalizeOCR', **img_norm_cfg),
|
||||||
|
dict(
|
||||||
|
type='Collect',
|
||||||
|
keys=['img'],
|
||||||
|
meta_keys=[
|
||||||
|
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
|
||||||
|
]),
|
||||||
|
]
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='ResizeOCR',
|
||||||
|
height=32,
|
||||||
|
min_width=4,
|
||||||
|
max_width=None,
|
||||||
|
keep_aspect_ratio=True),
|
||||||
|
dict(type='ToTensorOCR'),
|
||||||
|
dict(type='NormalizeOCR', **img_norm_cfg),
|
||||||
|
dict(
|
||||||
|
type='Collect',
|
||||||
|
keys=['img'],
|
||||||
|
meta_keys=['filename', 'ori_shape', 'img_shape', 'valid_ratio']),
|
||||||
|
]
|
||||||
|
|
||||||
|
dataset_type = 'OCRDataset'
|
||||||
|
|
||||||
|
train_img_prefix = 'data/mixture/mnt/ramdisk/max/90kDICT32px'
|
||||||
|
train_ann_file = 'data/mixture/mnt/ramdisk/max/90kDICT32px/label.txt'
|
||||||
|
|
||||||
|
train1 = dict(
|
||||||
|
type=dataset_type,
|
||||||
|
img_prefix=train_img_prefix,
|
||||||
|
ann_file=train_ann_file,
|
||||||
|
loader=dict(
|
||||||
|
type='HardDiskLoader',
|
||||||
|
repeat=1,
|
||||||
|
parser=dict(
|
||||||
|
type='LineStrParser',
|
||||||
|
keys=['filename', 'text'],
|
||||||
|
keys_idx=[0, 1],
|
||||||
|
separator=' ')),
|
||||||
|
pipeline=train_pipeline,
|
||||||
|
test_mode=False)
|
||||||
|
|
||||||
|
test1 = dict(
|
||||||
|
type=dataset_type,
|
||||||
|
img_prefix=train_img_prefix,
|
||||||
|
ann_file=train_ann_file,
|
||||||
|
loader=dict(
|
||||||
|
type='HardDiskLoader',
|
||||||
|
repeat=1,
|
||||||
|
parser=dict(
|
||||||
|
type='LineStrParser',
|
||||||
|
keys=['filename', 'text'],
|
||||||
|
keys_idx=[0, 1],
|
||||||
|
separator=' ')),
|
||||||
|
pipeline=test_pipeline,
|
||||||
|
test_mode=True)
|
||||||
|
|
||||||
|
test_img_prefix = 'data/mixture/'
|
||||||
|
ic13_path = 'testset/icdar_2013/Challenge2_Test_Task3_Images/'
|
||||||
|
test_img_prefix1 = test_img_prefix + ic13_path
|
||||||
|
test_img_prefix2 = test_img_prefix + 'testset/IIIT5K/'
|
||||||
|
test_img_prefix3 = test_img_prefix + 'testset/svt/'
|
||||||
|
|
||||||
|
test_ann_prefix = 'data/mixture/'
|
||||||
|
test_ann_file1 = test_ann_prefix + 'testset/icdar_2013/test_label_1015.txt'
|
||||||
|
test_ann_file2 = test_ann_prefix + 'testset/IIIT5K/label.txt'
|
||||||
|
test_ann_file3 = test_ann_prefix + 'testset/svt/test_list.txt'
|
||||||
|
|
||||||
|
test1 = dict(
|
||||||
|
type=dataset_type,
|
||||||
|
img_prefix=test_img_prefix1,
|
||||||
|
ann_file=test_ann_file1,
|
||||||
|
loader=dict(
|
||||||
|
type='HardDiskLoader',
|
||||||
|
repeat=1,
|
||||||
|
parser=dict(
|
||||||
|
type='LineStrParser',
|
||||||
|
keys=['filename', 'text'],
|
||||||
|
keys_idx=[0, 1],
|
||||||
|
separator=' ')),
|
||||||
|
pipeline=test_pipeline,
|
||||||
|
test_mode=True)
|
||||||
|
|
||||||
|
test2 = {key: value for key, value in test1.items()}
|
||||||
|
test2['img_prefix'] = test_img_prefix2
|
||||||
|
test2['ann_file'] = test_ann_file2
|
||||||
|
|
||||||
|
test3 = {key: value for key, value in test1.items()}
|
||||||
|
test3['img_prefix'] = test_img_prefix3
|
||||||
|
test3['ann_file'] = test_ann_file3
|
||||||
|
|
||||||
|
data = dict(
|
||||||
|
samples_per_gpu=64,
|
||||||
|
workers_per_gpu=4,
|
||||||
|
train=dict(type='ConcatDataset', datasets=[train1]),
|
||||||
|
val=dict(type='ConcatDataset', datasets=[test1, test2, test3]),
|
||||||
|
test=dict(type='ConcatDataset', datasets=[test1, test2, test3]))
|
||||||
|
|
||||||
|
evaluation = dict(interval=1, metric='acc')
|
||||||
|
|
||||||
|
cudnn_benchmark = True
|
|
@ -0,0 +1,6 @@
|
||||||
|
_base_ = [
|
||||||
|
'../../_base_/schedules/schedule_adadelta_8e.py',
|
||||||
|
'../../_base_/default_runtime.py',
|
||||||
|
'../../_base_/recog_datasets/toy_dataset.py',
|
||||||
|
'../../_base_/recog_models/crnn.py'
|
||||||
|
]
|
|
@ -0,0 +1,12 @@
|
||||||
|
_base_ = [
|
||||||
|
'../../_base_/default_runtime.py',
|
||||||
|
'../../_base_/recog_models/robust_scanner.py',
|
||||||
|
'../../_base_/recog_datasets/toy_dataset.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
optimizer = dict(type='Adam', lr=1e-3)
|
||||||
|
optimizer_config = dict(grad_clip=None)
|
||||||
|
# learning policy
|
||||||
|
lr_config = dict(policy='step', step=[3, 4])
|
||||||
|
total_epochs = 6
|
|
@ -0,0 +1,198 @@
|
||||||
|
_base_ = [
|
||||||
|
'../../_base_/default_runtime.py',
|
||||||
|
'../../_base_/recog_models/robust_scanner.py'
|
||||||
|
]
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
optimizer = dict(type='Adam', lr=1e-3)
|
||||||
|
optimizer_config = dict(grad_clip=None)
|
||||||
|
# learning policy
|
||||||
|
lr_config = dict(policy='step', step=[3, 4])
|
||||||
|
total_epochs = 5
|
||||||
|
|
||||||
|
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='ResizeOCR',
|
||||||
|
height=48,
|
||||||
|
min_width=48,
|
||||||
|
max_width=160,
|
||||||
|
keep_aspect_ratio=True,
|
||||||
|
width_downsample_ratio=0.25),
|
||||||
|
dict(type='ToTensorOCR'),
|
||||||
|
dict(type='NormalizeOCR', **img_norm_cfg),
|
||||||
|
dict(
|
||||||
|
type='Collect',
|
||||||
|
keys=['img'],
|
||||||
|
meta_keys=[
|
||||||
|
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
|
||||||
|
]),
|
||||||
|
]
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='MultiRotateAugOCR',
|
||||||
|
rotate_degrees=[0, 90, 270],
|
||||||
|
transforms=[
|
||||||
|
dict(
|
||||||
|
type='ResizeOCR',
|
||||||
|
height=48,
|
||||||
|
min_width=48,
|
||||||
|
max_width=160,
|
||||||
|
keep_aspect_ratio=True,
|
||||||
|
width_downsample_ratio=0.25),
|
||||||
|
dict(type='ToTensorOCR'),
|
||||||
|
dict(type='NormalizeOCR', **img_norm_cfg),
|
||||||
|
dict(
|
||||||
|
type='Collect',
|
||||||
|
keys=['img'],
|
||||||
|
meta_keys=[
|
||||||
|
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
|
||||||
|
]),
|
||||||
|
])
|
||||||
|
]
|
||||||
|
|
||||||
|
dataset_type = 'OCRDataset'
|
||||||
|
|
||||||
|
prefix = 'data/mixture/'
|
||||||
|
|
||||||
|
train_img_prefix1 = prefix + 'icdar_2011/Challenge1_Training_Task3_Images_GT'
|
||||||
|
train_img_prefix2 = prefix + 'icdar_2013/recog_train_data/' + \
|
||||||
|
'Challenge2_Training_Task3_Images_GT'
|
||||||
|
train_img_prefix3 = prefix + 'icdar_2015/ch4_training_word_images_gt'
|
||||||
|
train_img_prefix4 = prefix + 'coco_text/train_words'
|
||||||
|
train_img_prefix5 = prefix + 'III5K'
|
||||||
|
train_img_prefix6 = prefix + 'SynthText_Add/SynthText_Add'
|
||||||
|
train_img_prefix7 = prefix + 'SynthText/synthtext/SynthText_patch_horizontal'
|
||||||
|
train_img_prefix8 = prefix + 'mnt/ramdisk/max/90kDICT32px'
|
||||||
|
|
||||||
|
train_ann_file1 = prefix + 'icdar_2011/training_label_fix.txt',
|
||||||
|
train_ann_file2 = prefix + 'icdar_2013/recog_train_data/train_label.txt',
|
||||||
|
train_ann_file3 = prefix + 'icdar_2015/training_label_fix.txt',
|
||||||
|
train_ann_file4 = prefix + 'coco_text/train_label.txt',
|
||||||
|
train_ann_file5 = prefix + 'III5K/train_label.txt',
|
||||||
|
train_ann_file6 = prefix + 'SynthText_Add/SynthText_Add/' + \
|
||||||
|
'annotationlist/label.lmdb',
|
||||||
|
train_ann_file7 = prefix + 'SynthText/synthtext/shuffle_labels.lmdb',
|
||||||
|
train_ann_file8 = prefix + 'mnt/ramdisk/max/90kDICT32px/shuffle_labels.lmdb'
|
||||||
|
|
||||||
|
train1 = dict(
|
||||||
|
type=dataset_type,
|
||||||
|
img_prefix=train_img_prefix1,
|
||||||
|
ann_file=train_ann_file1,
|
||||||
|
loader=dict(
|
||||||
|
type='HardDiskLoader',
|
||||||
|
repeat=20,
|
||||||
|
parser=dict(
|
||||||
|
type='LineStrParser',
|
||||||
|
keys=['filename', 'text'],
|
||||||
|
keys_idx=[0, 1],
|
||||||
|
separator=' ')),
|
||||||
|
pipeline=train_pipeline,
|
||||||
|
test_mode=False)
|
||||||
|
|
||||||
|
train2 = {key: value for key, value in train1.items()}
|
||||||
|
train2['img_prefix'] = train_img_prefix2
|
||||||
|
train2['ann_file'] = train_ann_file2
|
||||||
|
|
||||||
|
train3 = {key: value for key, value in train1.items()}
|
||||||
|
train3['img_prefix'] = train_img_prefix3
|
||||||
|
train3['ann_file'] = train_ann_file3
|
||||||
|
|
||||||
|
train4 = {key: value for key, value in train1.items()}
|
||||||
|
train4['img_prefix'] = train_img_prefix4
|
||||||
|
train4['ann_file'] = train_ann_file4
|
||||||
|
|
||||||
|
train5 = {key: value for key, value in train1.items()}
|
||||||
|
train5['img_prefix'] = train_img_prefix5
|
||||||
|
train5['ann_file'] = train_ann_file5
|
||||||
|
|
||||||
|
train6 = dict(
|
||||||
|
type=dataset_type,
|
||||||
|
img_prefix=train_img_prefix6,
|
||||||
|
ann_file=train_ann_file6,
|
||||||
|
loader=dict(
|
||||||
|
type='LmdbLoader',
|
||||||
|
repeat=1,
|
||||||
|
parser=dict(
|
||||||
|
type='LineStrParser',
|
||||||
|
keys=['filename', 'text'],
|
||||||
|
keys_idx=[0, 1],
|
||||||
|
separator=' ')),
|
||||||
|
pipeline=train_pipeline,
|
||||||
|
test_mode=False)
|
||||||
|
|
||||||
|
train7 = {key: value for key, value in train6.items()}
|
||||||
|
train7['img_prefix'] = train_img_prefix7
|
||||||
|
train7['ann_file'] = train_ann_file7
|
||||||
|
|
||||||
|
train8 = {key: value for key, value in train6.items()}
|
||||||
|
train8['img_prefix'] = train_img_prefix8
|
||||||
|
train8['ann_file'] = train_ann_file8
|
||||||
|
|
||||||
|
test_img_prefix1 = prefix + 'testset/IIIT5K/'
|
||||||
|
test_img_prefix2 = prefix + 'testset/svt/'
|
||||||
|
test_img_prefix3 = prefix + 'testset/icdar_2013/Challenge2_Test_Task3_Images/'
|
||||||
|
test_img_prefix4 = prefix + 'testset/icdar_2015/ch4_test_word_images_gt'
|
||||||
|
test_img_prefix5 = prefix + 'testset/svtp/'
|
||||||
|
test_img_prefix6 = prefix + 'testset/ct80/'
|
||||||
|
|
||||||
|
test_ann_file1 = prefix + 'testset/IIIT5K/label.txt'
|
||||||
|
test_ann_file2 = prefix + 'testset/svt/test_list.txt'
|
||||||
|
test_ann_file3 = prefix + 'testset/icdar_2013/test_label_1015.txt'
|
||||||
|
test_ann_file4 = prefix + 'testset/icdar_2015/test_label.txt'
|
||||||
|
test_ann_file5 = prefix + 'testset/svtp/imagelist.txt'
|
||||||
|
test_ann_file6 = prefix + 'testset/ct80/imagelist.txt'
|
||||||
|
|
||||||
|
test1 = dict(
|
||||||
|
type=dataset_type,
|
||||||
|
img_prefix=test_img_prefix1,
|
||||||
|
ann_file=test_ann_file1,
|
||||||
|
loader=dict(
|
||||||
|
type='HardDiskLoader',
|
||||||
|
repeat=1,
|
||||||
|
parser=dict(
|
||||||
|
type='LineStrParser',
|
||||||
|
keys=['filename', 'text'],
|
||||||
|
keys_idx=[0, 1],
|
||||||
|
separator=' ')),
|
||||||
|
pipeline=test_pipeline,
|
||||||
|
test_mode=True)
|
||||||
|
|
||||||
|
test2 = {key: value for key, value in test1.items()}
|
||||||
|
test2['img_prefix'] = test_img_prefix2
|
||||||
|
test2['ann_file'] = test_ann_file2
|
||||||
|
|
||||||
|
test3 = {key: value for key, value in test1.items()}
|
||||||
|
test3['img_prefix'] = test_img_prefix3
|
||||||
|
test3['ann_file'] = test_ann_file3
|
||||||
|
|
||||||
|
test4 = {key: value for key, value in test1.items()}
|
||||||
|
test4['img_prefix'] = test_img_prefix4
|
||||||
|
test4['ann_file'] = test_ann_file4
|
||||||
|
|
||||||
|
test5 = {key: value for key, value in test1.items()}
|
||||||
|
test5['img_prefix'] = test_img_prefix5
|
||||||
|
test5['ann_file'] = test_ann_file5
|
||||||
|
|
||||||
|
test6 = {key: value for key, value in test1.items()}
|
||||||
|
test6['img_prefix'] = test_img_prefix6
|
||||||
|
test6['ann_file'] = test_ann_file6
|
||||||
|
|
||||||
|
data = dict(
|
||||||
|
samples_per_gpu=64,
|
||||||
|
workers_per_gpu=2,
|
||||||
|
train=dict(
|
||||||
|
type='ConcatDataset',
|
||||||
|
datasets=[
|
||||||
|
train1, train2, train3, train4, train5, train6, train7, train8
|
||||||
|
]),
|
||||||
|
val=dict(
|
||||||
|
type='ConcatDataset',
|
||||||
|
datasets=[test1, test2, test3, test4, test5, test6]),
|
||||||
|
test=dict(
|
||||||
|
type='ConcatDataset',
|
||||||
|
datasets=[test1, test2, test3, test4, test5, test6]))
|
||||||
|
|
||||||
|
evaluation = dict(interval=1, metric='acc')
|
|
@ -0,0 +1,216 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mmocr.utils as utils
|
||||||
|
from . import utils as eval_utils
|
||||||
|
|
||||||
|
|
||||||
|
def compute_recall_precision(gt_polys, pred_polys):
|
||||||
|
"""Compute the recall and the precision matrices between gt and predicted
|
||||||
|
polygons.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gt_polys (list[Polygon]): List of gt polygons.
|
||||||
|
pred_polys (list[Polygon]): List of predicted polygons.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
recall (ndarray): Recall matrix of size gt_num x det_num.
|
||||||
|
precision (ndarray): Precision matrix of size gt_num x det_num.
|
||||||
|
"""
|
||||||
|
assert isinstance(gt_polys, list)
|
||||||
|
assert isinstance(pred_polys, list)
|
||||||
|
|
||||||
|
gt_num = len(gt_polys)
|
||||||
|
det_num = len(pred_polys)
|
||||||
|
sz = [gt_num, det_num]
|
||||||
|
|
||||||
|
recall = np.zeros(sz)
|
||||||
|
precision = np.zeros(sz)
|
||||||
|
# compute area recall and precision for each (gt, det) pair
|
||||||
|
# in one img
|
||||||
|
for gt_id in range(gt_num):
|
||||||
|
for pred_id in range(det_num):
|
||||||
|
gt = gt_polys[gt_id]
|
||||||
|
det = pred_polys[pred_id]
|
||||||
|
|
||||||
|
inter_area, _ = eval_utils.poly_intersection(det, gt)
|
||||||
|
gt_area = gt.area()
|
||||||
|
det_area = det.area()
|
||||||
|
if gt_area != 0:
|
||||||
|
recall[gt_id, pred_id] = inter_area / gt_area
|
||||||
|
if det_area != 0:
|
||||||
|
precision[gt_id, pred_id] = inter_area / det_area
|
||||||
|
|
||||||
|
return recall, precision
|
||||||
|
|
||||||
|
|
||||||
|
def eval_hmean_ic13(det_boxes,
|
||||||
|
gt_boxes,
|
||||||
|
gt_ignored_boxes,
|
||||||
|
precision_thr=0.4,
|
||||||
|
recall_thr=0.8,
|
||||||
|
center_dist_thr=1.0,
|
||||||
|
one2one_score=1.,
|
||||||
|
one2many_score=0.8,
|
||||||
|
many2one_score=1.):
|
||||||
|
"""Evalute hmean of text detection using the icdar2013 standard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
det_boxes (list[list[list[float]]]): List of arrays of shape (n, 2k).
|
||||||
|
Each element is the det_boxes for one img. k>=4.
|
||||||
|
gt_boxes (list[list[list[float]]]): List of arrays of shape (m, 2k).
|
||||||
|
Each element is the gt_boxes for one img. k>=4.
|
||||||
|
gt_ignored_boxes (list[list[list[float]]]): List of arrays of
|
||||||
|
(l, 2k). Each element is the ignored gt_boxes for one img. k>=4.
|
||||||
|
precision_thr (float): Precision threshold of the iou of one
|
||||||
|
(gt_box, det_box) pair.
|
||||||
|
recall_thr (float): Recall threshold of the iou of one
|
||||||
|
(gt_box, det_box) pair.
|
||||||
|
center_dist_thr (float): Distance threshold of one (gt_box, det_box)
|
||||||
|
center point pair.
|
||||||
|
one2one_score (float): Reward when one gt matches one det_box.
|
||||||
|
one2many_score (float): Reward when one gt matches many det_boxes.
|
||||||
|
many2one_score (float): Reward when many gts match one det_box.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
hmean (tuple[dict]): Tuple of dicts which encodes the hmean for
|
||||||
|
the dataset and all images.
|
||||||
|
"""
|
||||||
|
assert utils.is_3dlist(det_boxes)
|
||||||
|
assert utils.is_3dlist(gt_boxes)
|
||||||
|
assert utils.is_3dlist(gt_ignored_boxes)
|
||||||
|
|
||||||
|
assert 0 <= precision_thr <= 1
|
||||||
|
assert 0 <= recall_thr <= 1
|
||||||
|
assert center_dist_thr > 0
|
||||||
|
assert 0 <= one2one_score <= 1
|
||||||
|
assert 0 <= one2many_score <= 1
|
||||||
|
assert 0 <= many2one_score <= 1
|
||||||
|
|
||||||
|
img_num = len(det_boxes)
|
||||||
|
assert img_num == len(gt_boxes)
|
||||||
|
assert img_num == len(gt_ignored_boxes)
|
||||||
|
|
||||||
|
dataset_gt_num = 0
|
||||||
|
dataset_pred_num = 0
|
||||||
|
dataset_hit_recall = 0.0
|
||||||
|
dataset_hit_prec = 0.0
|
||||||
|
|
||||||
|
img_results = []
|
||||||
|
|
||||||
|
for i in range(img_num):
|
||||||
|
gt = gt_boxes[i]
|
||||||
|
gt_ignored = gt_ignored_boxes[i]
|
||||||
|
pred = det_boxes[i]
|
||||||
|
|
||||||
|
gt_num = len(gt)
|
||||||
|
ignored_num = len(gt_ignored)
|
||||||
|
pred_num = len(pred)
|
||||||
|
|
||||||
|
accum_recall = 0.
|
||||||
|
accum_precision = 0.
|
||||||
|
|
||||||
|
gt_points = gt + gt_ignored
|
||||||
|
gt_polys = [eval_utils.points2polygon(p) for p in gt_points]
|
||||||
|
gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))]
|
||||||
|
gt_num = len(gt_polys)
|
||||||
|
|
||||||
|
pred_polys, pred_points, pred_ignored_index = eval_utils.ignore_pred(
|
||||||
|
pred, gt_ignored_index, gt_polys, precision_thr)
|
||||||
|
|
||||||
|
if pred_num > 0 and gt_num > 0:
|
||||||
|
|
||||||
|
gt_hit = np.zeros(gt_num, np.int8).tolist()
|
||||||
|
pred_hit = np.zeros(pred_num, np.int8).tolist()
|
||||||
|
|
||||||
|
# compute area recall and precision for each (gt, pred) pair
|
||||||
|
# in one img.
|
||||||
|
recall_mat, precision_mat = compute_recall_precision(
|
||||||
|
gt_polys, pred_polys)
|
||||||
|
|
||||||
|
# match one gt to one pred box.
|
||||||
|
for gt_id in range(gt_num):
|
||||||
|
for pred_id in range(pred_num):
|
||||||
|
if gt_hit[gt_id] != 0 or pred_hit[
|
||||||
|
pred_id] != 0 or gt_id in gt_ignored_index \
|
||||||
|
or pred_id in pred_ignored_index:
|
||||||
|
continue
|
||||||
|
match = eval_utils.one2one_match_ic13(
|
||||||
|
gt_id, pred_id, recall_mat, precision_mat, recall_thr,
|
||||||
|
precision_thr)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
gt_point = np.array(gt_points[gt_id])
|
||||||
|
det_point = np.array(pred_points[pred_id])
|
||||||
|
|
||||||
|
norm_dist = eval_utils.box_center_distance(
|
||||||
|
det_point, gt_point)
|
||||||
|
norm_dist /= eval_utils.box_diag(
|
||||||
|
det_point) + eval_utils.box_diag(gt_point)
|
||||||
|
norm_dist *= 2.0
|
||||||
|
|
||||||
|
if norm_dist < center_dist_thr:
|
||||||
|
gt_hit[gt_id] = 1
|
||||||
|
pred_hit[pred_id] = 1
|
||||||
|
accum_recall += one2one_score
|
||||||
|
accum_precision += one2one_score
|
||||||
|
|
||||||
|
# match one gt to many det boxes.
|
||||||
|
for gt_id in range(gt_num):
|
||||||
|
if gt_id in gt_ignored_index:
|
||||||
|
continue
|
||||||
|
match, match_det_set = eval_utils.one2many_match_ic13(
|
||||||
|
gt_id, recall_mat, precision_mat, recall_thr,
|
||||||
|
precision_thr, gt_hit, pred_hit, pred_ignored_index)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
gt_hit[gt_id] = 1
|
||||||
|
accum_recall += one2many_score
|
||||||
|
accum_precision += one2many_score * len(match_det_set)
|
||||||
|
for pred_id in match_det_set:
|
||||||
|
pred_hit[pred_id] = 1
|
||||||
|
|
||||||
|
# match many gt to one det box. One pair of (det,gt) are matched
|
||||||
|
# successfully if their recall, precision, normalized distance
|
||||||
|
# meet some thresholds.
|
||||||
|
for pred_id in range(pred_num):
|
||||||
|
if pred_id in pred_ignored_index:
|
||||||
|
continue
|
||||||
|
|
||||||
|
match, match_gt_set = eval_utils.many2one_match_ic13(
|
||||||
|
pred_id, recall_mat, precision_mat, recall_thr,
|
||||||
|
precision_thr, gt_hit, pred_hit, gt_ignored_index)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
pred_hit[pred_id] = 1
|
||||||
|
accum_recall += many2one_score * len(match_gt_set)
|
||||||
|
accum_precision += many2one_score
|
||||||
|
for gt_id in match_gt_set:
|
||||||
|
gt_hit[gt_id] = 1
|
||||||
|
|
||||||
|
gt_care_number = gt_num - ignored_num
|
||||||
|
pred_care_number = pred_num - len(pred_ignored_index)
|
||||||
|
|
||||||
|
r, p, h = eval_utils.compute_hmean(accum_recall, accum_precision,
|
||||||
|
gt_care_number, pred_care_number)
|
||||||
|
|
||||||
|
img_results.append({'recall': r, 'precision': p, 'hmean': h})
|
||||||
|
|
||||||
|
dataset_gt_num += gt_care_number
|
||||||
|
dataset_pred_num += pred_care_number
|
||||||
|
dataset_hit_recall += accum_recall
|
||||||
|
dataset_hit_prec += accum_precision
|
||||||
|
|
||||||
|
total_r, total_p, total_h = eval_utils.compute_hmean(
|
||||||
|
dataset_hit_recall, dataset_hit_prec, dataset_gt_num, dataset_pred_num)
|
||||||
|
|
||||||
|
dataset_results = {
|
||||||
|
'num_gts': dataset_gt_num,
|
||||||
|
'num_dets': dataset_pred_num,
|
||||||
|
'num_recall': dataset_hit_recall,
|
||||||
|
'num_precision': dataset_hit_prec,
|
||||||
|
'recall': total_r,
|
||||||
|
'precision': total_p,
|
||||||
|
'hmean': total_h
|
||||||
|
}
|
||||||
|
|
||||||
|
return dataset_results, img_results
|
|
@ -0,0 +1,116 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mmocr.utils as utils
|
||||||
|
from . import utils as eval_utils
|
||||||
|
|
||||||
|
|
||||||
|
def eval_hmean_iou(pred_boxes,
|
||||||
|
gt_boxes,
|
||||||
|
gt_ignored_boxes,
|
||||||
|
iou_thr=0.5,
|
||||||
|
precision_thr=0.5):
|
||||||
|
"""Evalute hmean of text detection using IOU standard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pred_boxes (list[list[list[float]]]): Text boxes for an img list. Each
|
||||||
|
box has 2k (>=8) values.
|
||||||
|
gt_boxes (list[list[list[float]]]): Ground truth text boxes for an img
|
||||||
|
list. Each box has 2k (>=8) values.
|
||||||
|
gt_ignored_boxes (list[list[list[float]]]): Ignored ground truth text
|
||||||
|
boxes for an img list. Each box has 2k (>=8) values.
|
||||||
|
iou_thr (float): Iou threshold when one (gt_box, det_box) pair is
|
||||||
|
matched.
|
||||||
|
precision_thr (float): Precision threshold when one (gt_box, det_box)
|
||||||
|
pair is matched.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
hmean (tuple[dict]): Tuple of dicts indicates the hmean for the dataset
|
||||||
|
and all images.
|
||||||
|
"""
|
||||||
|
assert utils.is_3dlist(pred_boxes)
|
||||||
|
assert utils.is_3dlist(gt_boxes)
|
||||||
|
assert utils.is_3dlist(gt_ignored_boxes)
|
||||||
|
assert 0 <= iou_thr <= 1
|
||||||
|
assert 0 <= precision_thr <= 1
|
||||||
|
|
||||||
|
img_num = len(pred_boxes)
|
||||||
|
assert img_num == len(gt_boxes)
|
||||||
|
assert img_num == len(gt_ignored_boxes)
|
||||||
|
|
||||||
|
dataset_gt_num = 0
|
||||||
|
dataset_pred_num = 0
|
||||||
|
dataset_hit_num = 0
|
||||||
|
|
||||||
|
img_results = []
|
||||||
|
|
||||||
|
for i in range(img_num):
|
||||||
|
gt = gt_boxes[i]
|
||||||
|
gt_ignored = gt_ignored_boxes[i]
|
||||||
|
pred = pred_boxes[i]
|
||||||
|
|
||||||
|
gt_num = len(gt)
|
||||||
|
gt_ignored_num = len(gt_ignored)
|
||||||
|
pred_num = len(pred)
|
||||||
|
|
||||||
|
hit_num = 0
|
||||||
|
|
||||||
|
# get gt polygons.
|
||||||
|
gt_all = gt + gt_ignored
|
||||||
|
gt_polys = [eval_utils.points2polygon(p) for p in gt_all]
|
||||||
|
gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))]
|
||||||
|
gt_num = len(gt_polys)
|
||||||
|
pred_polys, _, pred_ignored_index = eval_utils.ignore_pred(
|
||||||
|
pred, gt_ignored_index, gt_polys, precision_thr)
|
||||||
|
|
||||||
|
# match.
|
||||||
|
if gt_num > 0 and pred_num > 0:
|
||||||
|
sz = [gt_num, pred_num]
|
||||||
|
iou_mat = np.zeros(sz)
|
||||||
|
|
||||||
|
gt_hit = np.zeros(gt_num, np.int8)
|
||||||
|
pred_hit = np.zeros(pred_num, np.int8)
|
||||||
|
|
||||||
|
for gt_id in range(gt_num):
|
||||||
|
for pred_id in range(pred_num):
|
||||||
|
gt_pol = gt_polys[gt_id]
|
||||||
|
det_pol = pred_polys[pred_id]
|
||||||
|
|
||||||
|
iou_mat[gt_id,
|
||||||
|
pred_id] = eval_utils.poly_iou(det_pol, gt_pol)
|
||||||
|
|
||||||
|
for gt_id in range(gt_num):
|
||||||
|
for pred_id in range(pred_num):
|
||||||
|
if gt_hit[gt_id] != 0 or pred_hit[
|
||||||
|
pred_id] != 0 or gt_id in gt_ignored_index \
|
||||||
|
or pred_id in pred_ignored_index:
|
||||||
|
continue
|
||||||
|
if iou_mat[gt_id, pred_id] > iou_thr:
|
||||||
|
gt_hit[gt_id] = 1
|
||||||
|
pred_hit[pred_id] = 1
|
||||||
|
hit_num += 1
|
||||||
|
|
||||||
|
gt_care_number = gt_num - gt_ignored_num
|
||||||
|
pred_care_number = pred_num - len(pred_ignored_index)
|
||||||
|
|
||||||
|
r, p, h = eval_utils.compute_hmean(hit_num, hit_num, gt_care_number,
|
||||||
|
pred_care_number)
|
||||||
|
|
||||||
|
img_results.append({'recall': r, 'precision': p, 'hmean': h})
|
||||||
|
|
||||||
|
dataset_hit_num += hit_num
|
||||||
|
dataset_gt_num += gt_care_number
|
||||||
|
dataset_pred_num += pred_care_number
|
||||||
|
|
||||||
|
dataset_r, dataset_p, dataset_h = eval_utils.compute_hmean(
|
||||||
|
dataset_hit_num, dataset_hit_num, dataset_gt_num, dataset_pred_num)
|
||||||
|
|
||||||
|
dataset_results = {
|
||||||
|
'num_gts': dataset_gt_num,
|
||||||
|
'num_dets': dataset_pred_num,
|
||||||
|
'num_match': dataset_hit_num,
|
||||||
|
'recall': dataset_r,
|
||||||
|
'precision': dataset_p,
|
||||||
|
'hmean': dataset_h
|
||||||
|
}
|
||||||
|
|
||||||
|
return dataset_results, img_results
|
|
@ -0,0 +1,64 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
from mmcv.cnn import uniform_init, xavier_init
|
||||||
|
|
||||||
|
from mmdet.models.builder import BACKBONES
|
||||||
|
|
||||||
|
|
||||||
|
@BACKBONES.register_module()
|
||||||
|
class VeryDeepVgg(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, leakyRelu=True, input_channels=3):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
ks = [3, 3, 3, 3, 3, 3, 2]
|
||||||
|
ps = [1, 1, 1, 1, 1, 1, 0]
|
||||||
|
ss = [1, 1, 1, 1, 1, 1, 1]
|
||||||
|
nm = [64, 128, 256, 256, 512, 512, 512]
|
||||||
|
|
||||||
|
self.channels = nm
|
||||||
|
|
||||||
|
cnn = nn.Sequential()
|
||||||
|
|
||||||
|
def convRelu(i, batchNormalization=False):
|
||||||
|
nIn = input_channels if i == 0 else nm[i - 1]
|
||||||
|
nOut = nm[i]
|
||||||
|
cnn.add_module('conv{0}'.format(i),
|
||||||
|
nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
|
||||||
|
if batchNormalization:
|
||||||
|
cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
|
||||||
|
if leakyRelu:
|
||||||
|
cnn.add_module('relu{0}'.format(i),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True))
|
||||||
|
else:
|
||||||
|
cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
|
||||||
|
|
||||||
|
convRelu(0)
|
||||||
|
cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64
|
||||||
|
convRelu(1)
|
||||||
|
cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32
|
||||||
|
convRelu(2, True)
|
||||||
|
convRelu(3)
|
||||||
|
cnn.add_module('pooling{0}'.format(2),
|
||||||
|
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
|
||||||
|
convRelu(4, True)
|
||||||
|
convRelu(5)
|
||||||
|
cnn.add_module('pooling{0}'.format(3),
|
||||||
|
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
|
||||||
|
convRelu(6, True) # 512x1x16
|
||||||
|
|
||||||
|
self.cnn = cnn
|
||||||
|
|
||||||
|
def init_weights(self, pretrained=None):
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
xavier_init(m)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
uniform_init(m)
|
||||||
|
|
||||||
|
def out_channels(self):
|
||||||
|
return self.channels[-1]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
output = self.cnn(x)
|
||||||
|
|
||||||
|
return output
|
|
@ -0,0 +1,138 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from mmocr.models.builder import DECODERS
|
||||||
|
from mmocr.models.textrecog.layers import (DotProductAttentionLayer,
|
||||||
|
PositionAwareLayer)
|
||||||
|
from .base_decoder import BaseDecoder
|
||||||
|
|
||||||
|
|
||||||
|
@DECODERS.register_module()
|
||||||
|
class PositionAttentionDecoder(BaseDecoder):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_classes=None,
|
||||||
|
rnn_layers=2,
|
||||||
|
dim_input=512,
|
||||||
|
dim_model=128,
|
||||||
|
max_seq_len=40,
|
||||||
|
mask=True,
|
||||||
|
return_feature=False,
|
||||||
|
encode_value=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.dim_input = dim_input
|
||||||
|
self.dim_model = dim_model
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.return_feature = return_feature
|
||||||
|
self.encode_value = encode_value
|
||||||
|
self.mask = mask
|
||||||
|
|
||||||
|
self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model)
|
||||||
|
|
||||||
|
self.position_aware_module = PositionAwareLayer(
|
||||||
|
self.dim_model, rnn_layers)
|
||||||
|
|
||||||
|
self.attention_layer = DotProductAttentionLayer()
|
||||||
|
|
||||||
|
self.prediction = None
|
||||||
|
if not self.return_feature:
|
||||||
|
pred_num_classes = num_classes - 1
|
||||||
|
self.prediction = nn.Linear(
|
||||||
|
dim_model if encode_value else dim_input, pred_num_classes)
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_position_index(self, length, batch_size, device=None):
|
||||||
|
position_index = torch.arange(0, length, device=device)
|
||||||
|
position_index = position_index.repeat([batch_size, 1])
|
||||||
|
position_index = position_index.long()
|
||||||
|
return position_index
|
||||||
|
|
||||||
|
def forward_train(self, feat, out_enc, targets_dict, img_metas):
|
||||||
|
valid_ratios = [
|
||||||
|
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||||
|
] if self.mask else None
|
||||||
|
|
||||||
|
targets = targets_dict['padded_targets'].to(feat.device)
|
||||||
|
|
||||||
|
#
|
||||||
|
n, c_enc, h, w = out_enc.size()
|
||||||
|
assert c_enc == self.dim_model
|
||||||
|
_, c_feat, _, _ = feat.size()
|
||||||
|
assert c_feat == self.dim_input
|
||||||
|
_, len_q = targets.size()
|
||||||
|
assert len_q <= self.max_seq_len
|
||||||
|
|
||||||
|
position_index = self._get_position_index(len_q, n, feat.device)
|
||||||
|
|
||||||
|
position_out_enc = self.position_aware_module(out_enc)
|
||||||
|
|
||||||
|
query = self.embedding(position_index)
|
||||||
|
query = query.permute(0, 2, 1).contiguous()
|
||||||
|
key = position_out_enc.view(n, c_enc, h * w)
|
||||||
|
if self.encode_value:
|
||||||
|
value = out_enc.view(n, c_enc, h * w)
|
||||||
|
else:
|
||||||
|
value = feat.view(n, c_feat, h * w)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if valid_ratios is not None:
|
||||||
|
mask = query.new_zeros((n, h, w))
|
||||||
|
for i, valid_ratio in enumerate(valid_ratios):
|
||||||
|
valid_width = min(w, math.ceil(w * valid_ratio))
|
||||||
|
mask[i, :, valid_width:] = 1
|
||||||
|
mask = mask.bool()
|
||||||
|
mask = mask.view(n, h * w)
|
||||||
|
|
||||||
|
attn_out = self.attention_layer(query, key, value, mask)
|
||||||
|
attn_out = attn_out.permute(0, 2, 1).contiguous() # [n, len_q, dim_v]
|
||||||
|
|
||||||
|
if self.return_feature:
|
||||||
|
return attn_out
|
||||||
|
|
||||||
|
return self.prediction(attn_out)
|
||||||
|
|
||||||
|
def forward_test(self, feat, out_enc, img_metas):
|
||||||
|
valid_ratios = [
|
||||||
|
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||||
|
] if self.mask else None
|
||||||
|
|
||||||
|
seq_len = self.max_seq_len
|
||||||
|
n, c_enc, h, w = out_enc.size()
|
||||||
|
assert c_enc == self.dim_model
|
||||||
|
_, c_feat, _, _ = feat.size()
|
||||||
|
assert c_feat == self.dim_input
|
||||||
|
|
||||||
|
position_index = self._get_position_index(seq_len, n, feat.device)
|
||||||
|
|
||||||
|
position_out_enc = self.position_aware_module(out_enc)
|
||||||
|
|
||||||
|
query = self.embedding(position_index)
|
||||||
|
query = query.permute(0, 2, 1).contiguous()
|
||||||
|
key = position_out_enc.view(n, c_enc, h * w)
|
||||||
|
if self.encode_value:
|
||||||
|
value = out_enc.view(n, c_enc, h * w)
|
||||||
|
else:
|
||||||
|
value = feat.view(n, c_feat, h * w)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if valid_ratios is not None:
|
||||||
|
mask = query.new_zeros((n, h, w))
|
||||||
|
for i, valid_ratio in enumerate(valid_ratios):
|
||||||
|
valid_width = min(w, math.ceil(w * valid_ratio))
|
||||||
|
mask[i, :, valid_width:] = 1
|
||||||
|
mask = mask.bool()
|
||||||
|
mask = mask.view(n, h * w)
|
||||||
|
|
||||||
|
attn_out = self.attention_layer(query, key, value, mask)
|
||||||
|
attn_out = attn_out.permute(0, 2, 1).contiguous()
|
||||||
|
|
||||||
|
if self.return_feature:
|
||||||
|
return attn_out
|
||||||
|
|
||||||
|
return self.prediction(attn_out)
|
|
@ -0,0 +1,107 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from mmocr.models.builder import DECODERS, build_decoder
|
||||||
|
from mmocr.models.textrecog.layers import RobustScannerFusionLayer
|
||||||
|
from .base_decoder import BaseDecoder
|
||||||
|
|
||||||
|
|
||||||
|
@DECODERS.register_module()
|
||||||
|
class RobustScannerDecoder(BaseDecoder):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_classes=None,
|
||||||
|
dim_input=512,
|
||||||
|
dim_model=128,
|
||||||
|
max_seq_len=40,
|
||||||
|
start_idx=0,
|
||||||
|
mask=True,
|
||||||
|
padding_idx=None,
|
||||||
|
encode_value=False,
|
||||||
|
hybrid_decoder=None,
|
||||||
|
position_decoder=None):
|
||||||
|
super().__init__()
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.dim_input = dim_input
|
||||||
|
self.dim_model = dim_model
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.encode_value = encode_value
|
||||||
|
self.start_idx = start_idx
|
||||||
|
self.padding_idx = padding_idx
|
||||||
|
self.mask = mask
|
||||||
|
|
||||||
|
# init hybrid decoder
|
||||||
|
hybrid_decoder.update(num_classes=self.num_classes)
|
||||||
|
hybrid_decoder.update(dim_input=self.dim_input)
|
||||||
|
hybrid_decoder.update(dim_model=self.dim_model)
|
||||||
|
hybrid_decoder.update(start_idx=self.start_idx)
|
||||||
|
hybrid_decoder.update(padding_idx=self.padding_idx)
|
||||||
|
hybrid_decoder.update(max_seq_len=self.max_seq_len)
|
||||||
|
hybrid_decoder.update(mask=self.mask)
|
||||||
|
hybrid_decoder.update(encode_value=self.encode_value)
|
||||||
|
hybrid_decoder.update(return_feature=True)
|
||||||
|
|
||||||
|
self.hybrid_decoder = build_decoder(hybrid_decoder)
|
||||||
|
|
||||||
|
# init position decoder
|
||||||
|
position_decoder.update(num_classes=self.num_classes)
|
||||||
|
position_decoder.update(dim_input=self.dim_input)
|
||||||
|
position_decoder.update(dim_model=self.dim_model)
|
||||||
|
position_decoder.update(max_seq_len=self.max_seq_len)
|
||||||
|
position_decoder.update(mask=self.mask)
|
||||||
|
position_decoder.update(encode_value=self.encode_value)
|
||||||
|
position_decoder.update(return_feature=True)
|
||||||
|
|
||||||
|
self.position_decoder = build_decoder(position_decoder)
|
||||||
|
|
||||||
|
self.fusion_module = RobustScannerFusionLayer(
|
||||||
|
self.dim_model if encode_value else dim_input)
|
||||||
|
|
||||||
|
pred_num_classes = num_classes - 1
|
||||||
|
self.prediction = nn.Linear(dim_model if encode_value else dim_input,
|
||||||
|
pred_num_classes)
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def forward_train(self, feat, out_enc, targets_dict, img_metas):
|
||||||
|
hybrid_glimpse = self.hybrid_decoder.forward_train(
|
||||||
|
feat, out_enc, targets_dict, img_metas)
|
||||||
|
position_glimpse = self.position_decoder.forward_train(
|
||||||
|
feat, out_enc, targets_dict, img_metas)
|
||||||
|
|
||||||
|
fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse)
|
||||||
|
|
||||||
|
out = self.prediction(fusion_out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward_test(self, feat, out_enc, img_metas):
|
||||||
|
seq_len = self.max_seq_len
|
||||||
|
batch_size = feat.size(0)
|
||||||
|
|
||||||
|
decode_sequence = (feat.new_ones(
|
||||||
|
(batch_size, seq_len)) * self.start_idx).long()
|
||||||
|
|
||||||
|
position_glimpse = self.position_decoder.forward_test(
|
||||||
|
feat, out_enc, img_metas)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for i in range(seq_len):
|
||||||
|
hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
|
||||||
|
feat, out_enc, decode_sequence, i, img_metas)
|
||||||
|
|
||||||
|
fusion_out = self.fusion_module(hybrid_glimpse_step,
|
||||||
|
position_glimpse[:, i, :])
|
||||||
|
|
||||||
|
char_out = self.prediction(fusion_out)
|
||||||
|
char_out = F.softmax(char_out, -1)
|
||||||
|
outputs.append(char_out)
|
||||||
|
_, max_idx = torch.max(char_out, dim=1, keepdim=False)
|
||||||
|
if i < seq_len - 1:
|
||||||
|
decode_sequence[:, i + 1] = max_idx
|
||||||
|
|
||||||
|
outputs = torch.stack(outputs, 1)
|
||||||
|
|
||||||
|
return outputs
|
|
@ -0,0 +1,165 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from mmocr.models.builder import DECODERS
|
||||||
|
from mmocr.models.textrecog.layers import DotProductAttentionLayer
|
||||||
|
from .base_decoder import BaseDecoder
|
||||||
|
|
||||||
|
|
||||||
|
@DECODERS.register_module()
|
||||||
|
class SequenceAttentionDecoder(BaseDecoder):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
num_classes=None,
|
||||||
|
rnn_layers=2,
|
||||||
|
dim_input=512,
|
||||||
|
dim_model=128,
|
||||||
|
max_seq_len=40,
|
||||||
|
start_idx=0,
|
||||||
|
mask=True,
|
||||||
|
padding_idx=None,
|
||||||
|
dropout_ratio=0,
|
||||||
|
return_feature=False,
|
||||||
|
encode_value=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.dim_input = dim_input
|
||||||
|
self.dim_model = dim_model
|
||||||
|
self.return_feature = return_feature
|
||||||
|
self.encode_value = encode_value
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.start_idx = start_idx
|
||||||
|
self.mask = mask
|
||||||
|
|
||||||
|
self.embedding = nn.Embedding(
|
||||||
|
self.num_classes, self.dim_model, padding_idx=padding_idx)
|
||||||
|
|
||||||
|
self.sequence_layer = nn.LSTM(
|
||||||
|
input_size=dim_model,
|
||||||
|
hidden_size=dim_model,
|
||||||
|
num_layers=rnn_layers,
|
||||||
|
batch_first=True,
|
||||||
|
dropout=dropout_ratio)
|
||||||
|
|
||||||
|
self.attention_layer = DotProductAttentionLayer()
|
||||||
|
|
||||||
|
self.prediction = None
|
||||||
|
if not self.return_feature:
|
||||||
|
pred_num_classes = num_classes - 1
|
||||||
|
self.prediction = nn.Linear(
|
||||||
|
dim_model if encode_value else dim_input, pred_num_classes)
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def forward_train(self, feat, out_enc, targets_dict, img_metas):
|
||||||
|
valid_ratios = [
|
||||||
|
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||||
|
] if self.mask else None
|
||||||
|
|
||||||
|
targets = targets_dict['padded_targets'].to(feat.device)
|
||||||
|
tgt_embedding = self.embedding(targets)
|
||||||
|
|
||||||
|
n, c_enc, h, w = out_enc.size()
|
||||||
|
assert c_enc == self.dim_model
|
||||||
|
_, c_feat, _, _ = feat.size()
|
||||||
|
assert c_feat == self.dim_input
|
||||||
|
_, len_q, c_q = tgt_embedding.size()
|
||||||
|
assert c_q == self.dim_model
|
||||||
|
assert len_q <= self.max_seq_len
|
||||||
|
|
||||||
|
query, _ = self.sequence_layer(tgt_embedding)
|
||||||
|
query = query.permute(0, 2, 1).contiguous()
|
||||||
|
key = out_enc.view(n, c_enc, h * w)
|
||||||
|
if self.encode_value:
|
||||||
|
value = key
|
||||||
|
else:
|
||||||
|
value = feat.view(n, c_feat, h * w)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if valid_ratios is not None:
|
||||||
|
mask = query.new_zeros((n, h, w))
|
||||||
|
for i, valid_ratio in enumerate(valid_ratios):
|
||||||
|
valid_width = min(w, math.ceil(w * valid_ratio))
|
||||||
|
mask[i, :, valid_width:] = 1
|
||||||
|
mask = mask.bool()
|
||||||
|
mask = mask.view(n, h * w)
|
||||||
|
|
||||||
|
attn_out = self.attention_layer(query, key, value, mask)
|
||||||
|
attn_out = attn_out.permute(0, 2, 1).contiguous()
|
||||||
|
|
||||||
|
if self.return_feature:
|
||||||
|
return attn_out
|
||||||
|
|
||||||
|
out = self.prediction(attn_out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward_test(self, feat, out_enc, img_metas):
|
||||||
|
seq_len = self.max_seq_len
|
||||||
|
batch_size = feat.size(0)
|
||||||
|
|
||||||
|
decode_sequence = (feat.new_ones(
|
||||||
|
(batch_size, seq_len)) * self.start_idx).long()
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for i in range(seq_len):
|
||||||
|
step_out = self.forward_test_step(feat, out_enc, decode_sequence,
|
||||||
|
i, img_metas)
|
||||||
|
outputs.append(step_out)
|
||||||
|
_, max_idx = torch.max(step_out, dim=1, keepdim=False)
|
||||||
|
if i < seq_len - 1:
|
||||||
|
decode_sequence[:, i + 1] = max_idx
|
||||||
|
|
||||||
|
outputs = torch.stack(outputs, 1)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
|
||||||
|
img_metas):
|
||||||
|
valid_ratios = [
|
||||||
|
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||||
|
] if self.mask else None
|
||||||
|
|
||||||
|
embed = self.embedding(decode_sequence)
|
||||||
|
|
||||||
|
n, c_enc, h, w = out_enc.size()
|
||||||
|
assert c_enc == self.dim_model
|
||||||
|
_, c_feat, _, _ = feat.size()
|
||||||
|
assert c_feat == self.dim_input
|
||||||
|
_, _, c_q = embed.size()
|
||||||
|
assert c_q == self.dim_model
|
||||||
|
|
||||||
|
query, _ = self.sequence_layer(embed)
|
||||||
|
query = query.permute(0, 2, 1).contiguous()
|
||||||
|
key = out_enc.view(n, c_enc, h * w)
|
||||||
|
if self.encode_value:
|
||||||
|
value = key
|
||||||
|
else:
|
||||||
|
value = feat.view(n, c_feat, h * w)
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if valid_ratios is not None:
|
||||||
|
mask = query.new_zeros((n, h, w))
|
||||||
|
for i, valid_ratio in enumerate(valid_ratios):
|
||||||
|
valid_width = min(w, math.ceil(w * valid_ratio))
|
||||||
|
mask[i, :, valid_width:] = 1
|
||||||
|
mask = mask.bool()
|
||||||
|
mask = mask.view(n, h * w)
|
||||||
|
|
||||||
|
# [n, c, l]
|
||||||
|
attn_out = self.attention_layer(query, key, value, mask)
|
||||||
|
|
||||||
|
out = attn_out[:, :, current_step]
|
||||||
|
|
||||||
|
if self.return_feature:
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = self.prediction(out)
|
||||||
|
out = F.softmax(out, dim=-1)
|
||||||
|
|
||||||
|
return out
|
|
@ -0,0 +1,23 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
from mmcv.cnn import xavier_init
|
||||||
|
|
||||||
|
from mmocr.models.builder import ENCODERS
|
||||||
|
from .base_encoder import BaseEncoder
|
||||||
|
|
||||||
|
|
||||||
|
@ENCODERS.register_module()
|
||||||
|
class ChannelReductionEncoder(BaseEncoder):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layer = nn.Conv2d(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
xavier_init(m)
|
||||||
|
|
||||||
|
def forward(self, feat, img_metas=None):
|
||||||
|
return self.layer(feat)
|
|
@ -0,0 +1,27 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class DotProductAttentionLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim_model=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.scale = dim_model**-0.5 if dim_model is not None else 1.
|
||||||
|
|
||||||
|
def forward(self, query, key, value, mask=None):
|
||||||
|
n, seq_len = mask.size()
|
||||||
|
logits = torch.matmul(query.permute(0, 2, 1), key) * self.scale
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.view(n, 1, seq_len)
|
||||||
|
logits = logits.masked_fill(mask, float('-inf'))
|
||||||
|
|
||||||
|
weights = F.softmax(logits, dim=2)
|
||||||
|
|
||||||
|
glimpse = torch.matmul(weights, value.transpose(1, 2))
|
||||||
|
|
||||||
|
glimpse = glimpse.permute(0, 2, 1).contiguous()
|
||||||
|
|
||||||
|
return glimpse
|
|
@ -0,0 +1,35 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class PositionAwareLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim_model, rnn_layers=2):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim_model = dim_model
|
||||||
|
|
||||||
|
self.rnn = nn.LSTM(
|
||||||
|
input_size=dim_model,
|
||||||
|
hidden_size=dim_model,
|
||||||
|
num_layers=rnn_layers,
|
||||||
|
batch_first=True)
|
||||||
|
|
||||||
|
self.mixer = nn.Sequential(
|
||||||
|
nn.Conv2d(
|
||||||
|
dim_model, dim_model, kernel_size=3, stride=1, padding=1),
|
||||||
|
nn.ReLU(True),
|
||||||
|
nn.Conv2d(
|
||||||
|
dim_model, dim_model, kernel_size=3, stride=1, padding=1))
|
||||||
|
|
||||||
|
def forward(self, img_feature):
|
||||||
|
n, c, h, w = img_feature.size()
|
||||||
|
|
||||||
|
rnn_input = img_feature.permute(0, 2, 3, 1).contiguous()
|
||||||
|
rnn_input = rnn_input.view(n * h, w, c)
|
||||||
|
rnn_output, _ = self.rnn(rnn_input)
|
||||||
|
rnn_output = rnn_output.view(n, h, w, c)
|
||||||
|
rnn_output = rnn_output.permute(0, 3, 1, 2).contiguous()
|
||||||
|
|
||||||
|
out = self.mixer(rnn_output)
|
||||||
|
|
||||||
|
return out
|
|
@ -0,0 +1,22 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class RobustScannerFusionLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim_model, dim=-1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim_model = dim_model
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2)
|
||||||
|
self.glu_layer = nn.GLU(dim=dim)
|
||||||
|
|
||||||
|
def forward(self, x0, x1):
|
||||||
|
assert x0.size() == x1.size()
|
||||||
|
fusion_input = torch.cat([x0, x1], self.dim)
|
||||||
|
output = self.linear_layer(fusion_input)
|
||||||
|
output = self.glu_layer(output)
|
||||||
|
|
||||||
|
return output
|
|
@ -0,0 +1,10 @@
|
||||||
|
from mmdet.models.builder import DETECTORS
|
||||||
|
from .encode_decode_recognizer import EncodeDecodeRecognizer
|
||||||
|
|
||||||
|
|
||||||
|
@DETECTORS.register_module()
|
||||||
|
class RobustScanner(EncodeDecodeRecognizer):
|
||||||
|
"""Implementation of `RobustScanner.
|
||||||
|
|
||||||
|
<https://arxiv.org/pdf/2007.07542.pdf>
|
||||||
|
"""
|
Binary file not shown.
Before Width: | Height: | Size: 11 KiB |
Binary file not shown.
After Width: | Height: | Size: 31 KiB |
Loading…
Reference in New Issue