mirror of https://github.com/open-mmlab/mmocr.git
Merge pull request #5 from yuexy/feature/crnn_and_robustscanner
[feature]: add CRNN and RobustScannerpull/2/head
commit
50287450ab
|
@ -1,5 +1,5 @@
|
|||
<div align="center">
|
||||
<img src="resources/mmocr-logo.jpg" width="500px"/>
|
||||
<img src="resources/mmocr-logo.png" width="500px"/>
|
||||
</div>
|
||||
|
||||
## Introduction
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition
|
||||
|
||||
## Introduction
|
||||
|
||||
[ALGORITHM]
|
||||
|
||||
```latex
|
||||
@article{shi2016end,
|
||||
title={An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition},
|
||||
author={Shi, Baoguang and Bai, Xiang and Yao, Cong},
|
||||
journal={IEEE transactions on pattern analysis and machine intelligence},
|
||||
year={2016}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and Models
|
||||
|
||||
### Train Dataset
|
||||
|
||||
| trainset | instance_num | repeat_num | note |
|
||||
| :------: | :----------: | :--------: | :---: |
|
||||
| Syn90k | 8919273 | 1 | synth |
|
||||
|
||||
### Test Dataset
|
||||
|
||||
| testset | instance_num | note |
|
||||
| :-----: | :----------: | :-----: |
|
||||
| IIIT5K | 3000 | regular |
|
||||
| SVT | 647 | regular |
|
||||
| IC13 | 1015 | regular |
|
||||
|
||||
## Results and models
|
||||
|
||||
| methods | | Regular Text | | | | Irregular Text | | download |
|
||||
| :-----: | :----: | :----------: | :--: | :-: | :--: | :------------: | :--: | :---------------------------------------------------------: |
|
||||
| methods | IIIT5K | SVT | IC13 | | IC15 | SVTP | CT80 |
|
||||
| CRNN | 80.5 | 81.5 | 86.5 | | - | - | - | [config](https://download.openmmlab.com/mmocr/textrecog/crnn/crnn_academic_dataset.py) [log]() [model](https) |
|
|
@ -0,0 +1,155 @@
|
|||
_base_ = []
|
||||
checkpoint_config = dict(interval=1)
|
||||
# yapf:disable
|
||||
log_config = dict(
|
||||
interval=1,
|
||||
hooks=[
|
||||
dict(type='TextLoggerHook')
|
||||
|
||||
])
|
||||
# yapf:enable
|
||||
dist_params = dict(backend='nccl')
|
||||
log_level = 'INFO'
|
||||
load_from = None
|
||||
resume_from = None
|
||||
workflow = [('train', 1)]
|
||||
|
||||
# model
|
||||
label_convertor = dict(
|
||||
type='CTCConvertor', dict_type='DICT36', with_unknown=False, lower=True)
|
||||
|
||||
model = dict(
|
||||
type='CRNNNet',
|
||||
preprocessor=None,
|
||||
backbone=dict(type='VeryDeepVgg', leakyRelu=False, input_channels=1),
|
||||
encoder=None,
|
||||
decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
|
||||
loss=dict(type='CTCLoss'),
|
||||
label_convertor=label_convertor,
|
||||
pretrained=None)
|
||||
|
||||
train_cfg = None
|
||||
test_cfg = None
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(type='Adadelta', lr=1.0)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
# learning policy
|
||||
lr_config = dict(policy='step', step=[])
|
||||
total_epochs = 5
|
||||
|
||||
# data
|
||||
img_norm_cfg = dict(mean=[0.5], std=[0.5])
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeOCR',
|
||||
height=32,
|
||||
min_width=100,
|
||||
max_width=100,
|
||||
keep_aspect_ratio=False),
|
||||
dict(type='ToTensorOCR'),
|
||||
dict(type='NormalizeOCR', **img_norm_cfg),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img'],
|
||||
meta_keys=[
|
||||
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
|
||||
]),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeOCR',
|
||||
height=32,
|
||||
min_width=4,
|
||||
max_width=None,
|
||||
keep_aspect_ratio=True),
|
||||
dict(type='ToTensorOCR'),
|
||||
dict(type='NormalizeOCR', **img_norm_cfg),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img'],
|
||||
meta_keys=['filename', 'ori_shape', 'img_shape', 'valid_ratio']),
|
||||
]
|
||||
|
||||
dataset_type = 'OCRDataset'
|
||||
|
||||
train_img_prefix = 'data/mixture/mnt/ramdisk/max/90kDICT32px'
|
||||
train_ann_file = 'data/mixture/mnt/ramdisk/max/90kDICT32px/label.txt'
|
||||
|
||||
train1 = dict(
|
||||
type=dataset_type,
|
||||
img_prefix=train_img_prefix,
|
||||
ann_file=train_ann_file,
|
||||
loader=dict(
|
||||
type='HardDiskLoader',
|
||||
repeat=1,
|
||||
parser=dict(
|
||||
type='LineStrParser',
|
||||
keys=['filename', 'text'],
|
||||
keys_idx=[0, 1],
|
||||
separator=' ')),
|
||||
pipeline=train_pipeline,
|
||||
test_mode=False)
|
||||
|
||||
test1 = dict(
|
||||
type=dataset_type,
|
||||
img_prefix=train_img_prefix,
|
||||
ann_file=train_ann_file,
|
||||
loader=dict(
|
||||
type='HardDiskLoader',
|
||||
repeat=1,
|
||||
parser=dict(
|
||||
type='LineStrParser',
|
||||
keys=['filename', 'text'],
|
||||
keys_idx=[0, 1],
|
||||
separator=' ')),
|
||||
pipeline=test_pipeline,
|
||||
test_mode=True)
|
||||
|
||||
test_img_prefix = 'data/mixture/'
|
||||
ic13_path = 'testset/icdar_2013/Challenge2_Test_Task3_Images/'
|
||||
test_img_prefix1 = test_img_prefix + ic13_path
|
||||
test_img_prefix2 = test_img_prefix + 'testset/IIIT5K/'
|
||||
test_img_prefix3 = test_img_prefix + 'testset/svt/'
|
||||
|
||||
test_ann_prefix = 'data/mixture/'
|
||||
test_ann_file1 = test_ann_prefix + 'testset/icdar_2013/test_label_1015.txt'
|
||||
test_ann_file2 = test_ann_prefix + 'testset/IIIT5K/label.txt'
|
||||
test_ann_file3 = test_ann_prefix + 'testset/svt/test_list.txt'
|
||||
|
||||
test1 = dict(
|
||||
type=dataset_type,
|
||||
img_prefix=test_img_prefix1,
|
||||
ann_file=test_ann_file1,
|
||||
loader=dict(
|
||||
type='HardDiskLoader',
|
||||
repeat=1,
|
||||
parser=dict(
|
||||
type='LineStrParser',
|
||||
keys=['filename', 'text'],
|
||||
keys_idx=[0, 1],
|
||||
separator=' ')),
|
||||
pipeline=test_pipeline,
|
||||
test_mode=True)
|
||||
|
||||
test2 = {key: value for key, value in test1.items()}
|
||||
test2['img_prefix'] = test_img_prefix2
|
||||
test2['ann_file'] = test_ann_file2
|
||||
|
||||
test3 = {key: value for key, value in test1.items()}
|
||||
test3['img_prefix'] = test_img_prefix3
|
||||
test3['ann_file'] = test_ann_file3
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=64,
|
||||
workers_per_gpu=4,
|
||||
train=dict(type='ConcatDataset', datasets=[train1]),
|
||||
val=dict(type='ConcatDataset', datasets=[test1, test2, test3]),
|
||||
test=dict(type='ConcatDataset', datasets=[test1, test2, test3]))
|
||||
|
||||
evaluation = dict(interval=1, metric='acc')
|
||||
|
||||
cudnn_benchmark = True
|
|
@ -0,0 +1,6 @@
|
|||
_base_ = [
|
||||
'../../_base_/schedules/schedule_adadelta_8e.py',
|
||||
'../../_base_/default_runtime.py',
|
||||
'../../_base_/recog_datasets/toy_dataset.py',
|
||||
'../../_base_/recog_models/crnn.py'
|
||||
]
|
|
@ -0,0 +1,12 @@
|
|||
_base_ = [
|
||||
'../../_base_/default_runtime.py',
|
||||
'../../_base_/recog_models/robust_scanner.py',
|
||||
'../../_base_/recog_datasets/toy_dataset.py'
|
||||
]
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(type='Adam', lr=1e-3)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
# learning policy
|
||||
lr_config = dict(policy='step', step=[3, 4])
|
||||
total_epochs = 6
|
|
@ -0,0 +1,198 @@
|
|||
_base_ = [
|
||||
'../../_base_/default_runtime.py',
|
||||
'../../_base_/recog_models/robust_scanner.py'
|
||||
]
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(type='Adam', lr=1e-3)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
# learning policy
|
||||
lr_config = dict(policy='step', step=[3, 4])
|
||||
total_epochs = 5
|
||||
|
||||
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeOCR',
|
||||
height=48,
|
||||
min_width=48,
|
||||
max_width=160,
|
||||
keep_aspect_ratio=True,
|
||||
width_downsample_ratio=0.25),
|
||||
dict(type='ToTensorOCR'),
|
||||
dict(type='NormalizeOCR', **img_norm_cfg),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img'],
|
||||
meta_keys=[
|
||||
'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
|
||||
]),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiRotateAugOCR',
|
||||
rotate_degrees=[0, 90, 270],
|
||||
transforms=[
|
||||
dict(
|
||||
type='ResizeOCR',
|
||||
height=48,
|
||||
min_width=48,
|
||||
max_width=160,
|
||||
keep_aspect_ratio=True,
|
||||
width_downsample_ratio=0.25),
|
||||
dict(type='ToTensorOCR'),
|
||||
dict(type='NormalizeOCR', **img_norm_cfg),
|
||||
dict(
|
||||
type='Collect',
|
||||
keys=['img'],
|
||||
meta_keys=[
|
||||
'filename', 'ori_shape', 'img_shape', 'valid_ratio'
|
||||
]),
|
||||
])
|
||||
]
|
||||
|
||||
dataset_type = 'OCRDataset'
|
||||
|
||||
prefix = 'data/mixture/'
|
||||
|
||||
train_img_prefix1 = prefix + 'icdar_2011/Challenge1_Training_Task3_Images_GT'
|
||||
train_img_prefix2 = prefix + 'icdar_2013/recog_train_data/' + \
|
||||
'Challenge2_Training_Task3_Images_GT'
|
||||
train_img_prefix3 = prefix + 'icdar_2015/ch4_training_word_images_gt'
|
||||
train_img_prefix4 = prefix + 'coco_text/train_words'
|
||||
train_img_prefix5 = prefix + 'III5K'
|
||||
train_img_prefix6 = prefix + 'SynthText_Add/SynthText_Add'
|
||||
train_img_prefix7 = prefix + 'SynthText/synthtext/SynthText_patch_horizontal'
|
||||
train_img_prefix8 = prefix + 'mnt/ramdisk/max/90kDICT32px'
|
||||
|
||||
train_ann_file1 = prefix + 'icdar_2011/training_label_fix.txt',
|
||||
train_ann_file2 = prefix + 'icdar_2013/recog_train_data/train_label.txt',
|
||||
train_ann_file3 = prefix + 'icdar_2015/training_label_fix.txt',
|
||||
train_ann_file4 = prefix + 'coco_text/train_label.txt',
|
||||
train_ann_file5 = prefix + 'III5K/train_label.txt',
|
||||
train_ann_file6 = prefix + 'SynthText_Add/SynthText_Add/' + \
|
||||
'annotationlist/label.lmdb',
|
||||
train_ann_file7 = prefix + 'SynthText/synthtext/shuffle_labels.lmdb',
|
||||
train_ann_file8 = prefix + 'mnt/ramdisk/max/90kDICT32px/shuffle_labels.lmdb'
|
||||
|
||||
train1 = dict(
|
||||
type=dataset_type,
|
||||
img_prefix=train_img_prefix1,
|
||||
ann_file=train_ann_file1,
|
||||
loader=dict(
|
||||
type='HardDiskLoader',
|
||||
repeat=20,
|
||||
parser=dict(
|
||||
type='LineStrParser',
|
||||
keys=['filename', 'text'],
|
||||
keys_idx=[0, 1],
|
||||
separator=' ')),
|
||||
pipeline=train_pipeline,
|
||||
test_mode=False)
|
||||
|
||||
train2 = {key: value for key, value in train1.items()}
|
||||
train2['img_prefix'] = train_img_prefix2
|
||||
train2['ann_file'] = train_ann_file2
|
||||
|
||||
train3 = {key: value for key, value in train1.items()}
|
||||
train3['img_prefix'] = train_img_prefix3
|
||||
train3['ann_file'] = train_ann_file3
|
||||
|
||||
train4 = {key: value for key, value in train1.items()}
|
||||
train4['img_prefix'] = train_img_prefix4
|
||||
train4['ann_file'] = train_ann_file4
|
||||
|
||||
train5 = {key: value for key, value in train1.items()}
|
||||
train5['img_prefix'] = train_img_prefix5
|
||||
train5['ann_file'] = train_ann_file5
|
||||
|
||||
train6 = dict(
|
||||
type=dataset_type,
|
||||
img_prefix=train_img_prefix6,
|
||||
ann_file=train_ann_file6,
|
||||
loader=dict(
|
||||
type='LmdbLoader',
|
||||
repeat=1,
|
||||
parser=dict(
|
||||
type='LineStrParser',
|
||||
keys=['filename', 'text'],
|
||||
keys_idx=[0, 1],
|
||||
separator=' ')),
|
||||
pipeline=train_pipeline,
|
||||
test_mode=False)
|
||||
|
||||
train7 = {key: value for key, value in train6.items()}
|
||||
train7['img_prefix'] = train_img_prefix7
|
||||
train7['ann_file'] = train_ann_file7
|
||||
|
||||
train8 = {key: value for key, value in train6.items()}
|
||||
train8['img_prefix'] = train_img_prefix8
|
||||
train8['ann_file'] = train_ann_file8
|
||||
|
||||
test_img_prefix1 = prefix + 'testset/IIIT5K/'
|
||||
test_img_prefix2 = prefix + 'testset/svt/'
|
||||
test_img_prefix3 = prefix + 'testset/icdar_2013/Challenge2_Test_Task3_Images/'
|
||||
test_img_prefix4 = prefix + 'testset/icdar_2015/ch4_test_word_images_gt'
|
||||
test_img_prefix5 = prefix + 'testset/svtp/'
|
||||
test_img_prefix6 = prefix + 'testset/ct80/'
|
||||
|
||||
test_ann_file1 = prefix + 'testset/IIIT5K/label.txt'
|
||||
test_ann_file2 = prefix + 'testset/svt/test_list.txt'
|
||||
test_ann_file3 = prefix + 'testset/icdar_2013/test_label_1015.txt'
|
||||
test_ann_file4 = prefix + 'testset/icdar_2015/test_label.txt'
|
||||
test_ann_file5 = prefix + 'testset/svtp/imagelist.txt'
|
||||
test_ann_file6 = prefix + 'testset/ct80/imagelist.txt'
|
||||
|
||||
test1 = dict(
|
||||
type=dataset_type,
|
||||
img_prefix=test_img_prefix1,
|
||||
ann_file=test_ann_file1,
|
||||
loader=dict(
|
||||
type='HardDiskLoader',
|
||||
repeat=1,
|
||||
parser=dict(
|
||||
type='LineStrParser',
|
||||
keys=['filename', 'text'],
|
||||
keys_idx=[0, 1],
|
||||
separator=' ')),
|
||||
pipeline=test_pipeline,
|
||||
test_mode=True)
|
||||
|
||||
test2 = {key: value for key, value in test1.items()}
|
||||
test2['img_prefix'] = test_img_prefix2
|
||||
test2['ann_file'] = test_ann_file2
|
||||
|
||||
test3 = {key: value for key, value in test1.items()}
|
||||
test3['img_prefix'] = test_img_prefix3
|
||||
test3['ann_file'] = test_ann_file3
|
||||
|
||||
test4 = {key: value for key, value in test1.items()}
|
||||
test4['img_prefix'] = test_img_prefix4
|
||||
test4['ann_file'] = test_ann_file4
|
||||
|
||||
test5 = {key: value for key, value in test1.items()}
|
||||
test5['img_prefix'] = test_img_prefix5
|
||||
test5['ann_file'] = test_ann_file5
|
||||
|
||||
test6 = {key: value for key, value in test1.items()}
|
||||
test6['img_prefix'] = test_img_prefix6
|
||||
test6['ann_file'] = test_ann_file6
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=64,
|
||||
workers_per_gpu=2,
|
||||
train=dict(
|
||||
type='ConcatDataset',
|
||||
datasets=[
|
||||
train1, train2, train3, train4, train5, train6, train7, train8
|
||||
]),
|
||||
val=dict(
|
||||
type='ConcatDataset',
|
||||
datasets=[test1, test2, test3, test4, test5, test6]),
|
||||
test=dict(
|
||||
type='ConcatDataset',
|
||||
datasets=[test1, test2, test3, test4, test5, test6]))
|
||||
|
||||
evaluation = dict(interval=1, metric='acc')
|
|
@ -0,0 +1,216 @@
|
|||
import numpy as np
|
||||
|
||||
import mmocr.utils as utils
|
||||
from . import utils as eval_utils
|
||||
|
||||
|
||||
def compute_recall_precision(gt_polys, pred_polys):
|
||||
"""Compute the recall and the precision matrices between gt and predicted
|
||||
polygons.
|
||||
|
||||
Args:
|
||||
gt_polys (list[Polygon]): List of gt polygons.
|
||||
pred_polys (list[Polygon]): List of predicted polygons.
|
||||
|
||||
Returns:
|
||||
recall (ndarray): Recall matrix of size gt_num x det_num.
|
||||
precision (ndarray): Precision matrix of size gt_num x det_num.
|
||||
"""
|
||||
assert isinstance(gt_polys, list)
|
||||
assert isinstance(pred_polys, list)
|
||||
|
||||
gt_num = len(gt_polys)
|
||||
det_num = len(pred_polys)
|
||||
sz = [gt_num, det_num]
|
||||
|
||||
recall = np.zeros(sz)
|
||||
precision = np.zeros(sz)
|
||||
# compute area recall and precision for each (gt, det) pair
|
||||
# in one img
|
||||
for gt_id in range(gt_num):
|
||||
for pred_id in range(det_num):
|
||||
gt = gt_polys[gt_id]
|
||||
det = pred_polys[pred_id]
|
||||
|
||||
inter_area, _ = eval_utils.poly_intersection(det, gt)
|
||||
gt_area = gt.area()
|
||||
det_area = det.area()
|
||||
if gt_area != 0:
|
||||
recall[gt_id, pred_id] = inter_area / gt_area
|
||||
if det_area != 0:
|
||||
precision[gt_id, pred_id] = inter_area / det_area
|
||||
|
||||
return recall, precision
|
||||
|
||||
|
||||
def eval_hmean_ic13(det_boxes,
|
||||
gt_boxes,
|
||||
gt_ignored_boxes,
|
||||
precision_thr=0.4,
|
||||
recall_thr=0.8,
|
||||
center_dist_thr=1.0,
|
||||
one2one_score=1.,
|
||||
one2many_score=0.8,
|
||||
many2one_score=1.):
|
||||
"""Evalute hmean of text detection using the icdar2013 standard.
|
||||
|
||||
Args:
|
||||
det_boxes (list[list[list[float]]]): List of arrays of shape (n, 2k).
|
||||
Each element is the det_boxes for one img. k>=4.
|
||||
gt_boxes (list[list[list[float]]]): List of arrays of shape (m, 2k).
|
||||
Each element is the gt_boxes for one img. k>=4.
|
||||
gt_ignored_boxes (list[list[list[float]]]): List of arrays of
|
||||
(l, 2k). Each element is the ignored gt_boxes for one img. k>=4.
|
||||
precision_thr (float): Precision threshold of the iou of one
|
||||
(gt_box, det_box) pair.
|
||||
recall_thr (float): Recall threshold of the iou of one
|
||||
(gt_box, det_box) pair.
|
||||
center_dist_thr (float): Distance threshold of one (gt_box, det_box)
|
||||
center point pair.
|
||||
one2one_score (float): Reward when one gt matches one det_box.
|
||||
one2many_score (float): Reward when one gt matches many det_boxes.
|
||||
many2one_score (float): Reward when many gts match one det_box.
|
||||
|
||||
Returns:
|
||||
hmean (tuple[dict]): Tuple of dicts which encodes the hmean for
|
||||
the dataset and all images.
|
||||
"""
|
||||
assert utils.is_3dlist(det_boxes)
|
||||
assert utils.is_3dlist(gt_boxes)
|
||||
assert utils.is_3dlist(gt_ignored_boxes)
|
||||
|
||||
assert 0 <= precision_thr <= 1
|
||||
assert 0 <= recall_thr <= 1
|
||||
assert center_dist_thr > 0
|
||||
assert 0 <= one2one_score <= 1
|
||||
assert 0 <= one2many_score <= 1
|
||||
assert 0 <= many2one_score <= 1
|
||||
|
||||
img_num = len(det_boxes)
|
||||
assert img_num == len(gt_boxes)
|
||||
assert img_num == len(gt_ignored_boxes)
|
||||
|
||||
dataset_gt_num = 0
|
||||
dataset_pred_num = 0
|
||||
dataset_hit_recall = 0.0
|
||||
dataset_hit_prec = 0.0
|
||||
|
||||
img_results = []
|
||||
|
||||
for i in range(img_num):
|
||||
gt = gt_boxes[i]
|
||||
gt_ignored = gt_ignored_boxes[i]
|
||||
pred = det_boxes[i]
|
||||
|
||||
gt_num = len(gt)
|
||||
ignored_num = len(gt_ignored)
|
||||
pred_num = len(pred)
|
||||
|
||||
accum_recall = 0.
|
||||
accum_precision = 0.
|
||||
|
||||
gt_points = gt + gt_ignored
|
||||
gt_polys = [eval_utils.points2polygon(p) for p in gt_points]
|
||||
gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))]
|
||||
gt_num = len(gt_polys)
|
||||
|
||||
pred_polys, pred_points, pred_ignored_index = eval_utils.ignore_pred(
|
||||
pred, gt_ignored_index, gt_polys, precision_thr)
|
||||
|
||||
if pred_num > 0 and gt_num > 0:
|
||||
|
||||
gt_hit = np.zeros(gt_num, np.int8).tolist()
|
||||
pred_hit = np.zeros(pred_num, np.int8).tolist()
|
||||
|
||||
# compute area recall and precision for each (gt, pred) pair
|
||||
# in one img.
|
||||
recall_mat, precision_mat = compute_recall_precision(
|
||||
gt_polys, pred_polys)
|
||||
|
||||
# match one gt to one pred box.
|
||||
for gt_id in range(gt_num):
|
||||
for pred_id in range(pred_num):
|
||||
if gt_hit[gt_id] != 0 or pred_hit[
|
||||
pred_id] != 0 or gt_id in gt_ignored_index \
|
||||
or pred_id in pred_ignored_index:
|
||||
continue
|
||||
match = eval_utils.one2one_match_ic13(
|
||||
gt_id, pred_id, recall_mat, precision_mat, recall_thr,
|
||||
precision_thr)
|
||||
|
||||
if match:
|
||||
gt_point = np.array(gt_points[gt_id])
|
||||
det_point = np.array(pred_points[pred_id])
|
||||
|
||||
norm_dist = eval_utils.box_center_distance(
|
||||
det_point, gt_point)
|
||||
norm_dist /= eval_utils.box_diag(
|
||||
det_point) + eval_utils.box_diag(gt_point)
|
||||
norm_dist *= 2.0
|
||||
|
||||
if norm_dist < center_dist_thr:
|
||||
gt_hit[gt_id] = 1
|
||||
pred_hit[pred_id] = 1
|
||||
accum_recall += one2one_score
|
||||
accum_precision += one2one_score
|
||||
|
||||
# match one gt to many det boxes.
|
||||
for gt_id in range(gt_num):
|
||||
if gt_id in gt_ignored_index:
|
||||
continue
|
||||
match, match_det_set = eval_utils.one2many_match_ic13(
|
||||
gt_id, recall_mat, precision_mat, recall_thr,
|
||||
precision_thr, gt_hit, pred_hit, pred_ignored_index)
|
||||
|
||||
if match:
|
||||
gt_hit[gt_id] = 1
|
||||
accum_recall += one2many_score
|
||||
accum_precision += one2many_score * len(match_det_set)
|
||||
for pred_id in match_det_set:
|
||||
pred_hit[pred_id] = 1
|
||||
|
||||
# match many gt to one det box. One pair of (det,gt) are matched
|
||||
# successfully if their recall, precision, normalized distance
|
||||
# meet some thresholds.
|
||||
for pred_id in range(pred_num):
|
||||
if pred_id in pred_ignored_index:
|
||||
continue
|
||||
|
||||
match, match_gt_set = eval_utils.many2one_match_ic13(
|
||||
pred_id, recall_mat, precision_mat, recall_thr,
|
||||
precision_thr, gt_hit, pred_hit, gt_ignored_index)
|
||||
|
||||
if match:
|
||||
pred_hit[pred_id] = 1
|
||||
accum_recall += many2one_score * len(match_gt_set)
|
||||
accum_precision += many2one_score
|
||||
for gt_id in match_gt_set:
|
||||
gt_hit[gt_id] = 1
|
||||
|
||||
gt_care_number = gt_num - ignored_num
|
||||
pred_care_number = pred_num - len(pred_ignored_index)
|
||||
|
||||
r, p, h = eval_utils.compute_hmean(accum_recall, accum_precision,
|
||||
gt_care_number, pred_care_number)
|
||||
|
||||
img_results.append({'recall': r, 'precision': p, 'hmean': h})
|
||||
|
||||
dataset_gt_num += gt_care_number
|
||||
dataset_pred_num += pred_care_number
|
||||
dataset_hit_recall += accum_recall
|
||||
dataset_hit_prec += accum_precision
|
||||
|
||||
total_r, total_p, total_h = eval_utils.compute_hmean(
|
||||
dataset_hit_recall, dataset_hit_prec, dataset_gt_num, dataset_pred_num)
|
||||
|
||||
dataset_results = {
|
||||
'num_gts': dataset_gt_num,
|
||||
'num_dets': dataset_pred_num,
|
||||
'num_recall': dataset_hit_recall,
|
||||
'num_precision': dataset_hit_prec,
|
||||
'recall': total_r,
|
||||
'precision': total_p,
|
||||
'hmean': total_h
|
||||
}
|
||||
|
||||
return dataset_results, img_results
|
|
@ -0,0 +1,116 @@
|
|||
import numpy as np
|
||||
|
||||
import mmocr.utils as utils
|
||||
from . import utils as eval_utils
|
||||
|
||||
|
||||
def eval_hmean_iou(pred_boxes,
|
||||
gt_boxes,
|
||||
gt_ignored_boxes,
|
||||
iou_thr=0.5,
|
||||
precision_thr=0.5):
|
||||
"""Evalute hmean of text detection using IOU standard.
|
||||
|
||||
Args:
|
||||
pred_boxes (list[list[list[float]]]): Text boxes for an img list. Each
|
||||
box has 2k (>=8) values.
|
||||
gt_boxes (list[list[list[float]]]): Ground truth text boxes for an img
|
||||
list. Each box has 2k (>=8) values.
|
||||
gt_ignored_boxes (list[list[list[float]]]): Ignored ground truth text
|
||||
boxes for an img list. Each box has 2k (>=8) values.
|
||||
iou_thr (float): Iou threshold when one (gt_box, det_box) pair is
|
||||
matched.
|
||||
precision_thr (float): Precision threshold when one (gt_box, det_box)
|
||||
pair is matched.
|
||||
|
||||
Returns:
|
||||
hmean (tuple[dict]): Tuple of dicts indicates the hmean for the dataset
|
||||
and all images.
|
||||
"""
|
||||
assert utils.is_3dlist(pred_boxes)
|
||||
assert utils.is_3dlist(gt_boxes)
|
||||
assert utils.is_3dlist(gt_ignored_boxes)
|
||||
assert 0 <= iou_thr <= 1
|
||||
assert 0 <= precision_thr <= 1
|
||||
|
||||
img_num = len(pred_boxes)
|
||||
assert img_num == len(gt_boxes)
|
||||
assert img_num == len(gt_ignored_boxes)
|
||||
|
||||
dataset_gt_num = 0
|
||||
dataset_pred_num = 0
|
||||
dataset_hit_num = 0
|
||||
|
||||
img_results = []
|
||||
|
||||
for i in range(img_num):
|
||||
gt = gt_boxes[i]
|
||||
gt_ignored = gt_ignored_boxes[i]
|
||||
pred = pred_boxes[i]
|
||||
|
||||
gt_num = len(gt)
|
||||
gt_ignored_num = len(gt_ignored)
|
||||
pred_num = len(pred)
|
||||
|
||||
hit_num = 0
|
||||
|
||||
# get gt polygons.
|
||||
gt_all = gt + gt_ignored
|
||||
gt_polys = [eval_utils.points2polygon(p) for p in gt_all]
|
||||
gt_ignored_index = [gt_num + i for i in range(len(gt_ignored))]
|
||||
gt_num = len(gt_polys)
|
||||
pred_polys, _, pred_ignored_index = eval_utils.ignore_pred(
|
||||
pred, gt_ignored_index, gt_polys, precision_thr)
|
||||
|
||||
# match.
|
||||
if gt_num > 0 and pred_num > 0:
|
||||
sz = [gt_num, pred_num]
|
||||
iou_mat = np.zeros(sz)
|
||||
|
||||
gt_hit = np.zeros(gt_num, np.int8)
|
||||
pred_hit = np.zeros(pred_num, np.int8)
|
||||
|
||||
for gt_id in range(gt_num):
|
||||
for pred_id in range(pred_num):
|
||||
gt_pol = gt_polys[gt_id]
|
||||
det_pol = pred_polys[pred_id]
|
||||
|
||||
iou_mat[gt_id,
|
||||
pred_id] = eval_utils.poly_iou(det_pol, gt_pol)
|
||||
|
||||
for gt_id in range(gt_num):
|
||||
for pred_id in range(pred_num):
|
||||
if gt_hit[gt_id] != 0 or pred_hit[
|
||||
pred_id] != 0 or gt_id in gt_ignored_index \
|
||||
or pred_id in pred_ignored_index:
|
||||
continue
|
||||
if iou_mat[gt_id, pred_id] > iou_thr:
|
||||
gt_hit[gt_id] = 1
|
||||
pred_hit[pred_id] = 1
|
||||
hit_num += 1
|
||||
|
||||
gt_care_number = gt_num - gt_ignored_num
|
||||
pred_care_number = pred_num - len(pred_ignored_index)
|
||||
|
||||
r, p, h = eval_utils.compute_hmean(hit_num, hit_num, gt_care_number,
|
||||
pred_care_number)
|
||||
|
||||
img_results.append({'recall': r, 'precision': p, 'hmean': h})
|
||||
|
||||
dataset_hit_num += hit_num
|
||||
dataset_gt_num += gt_care_number
|
||||
dataset_pred_num += pred_care_number
|
||||
|
||||
dataset_r, dataset_p, dataset_h = eval_utils.compute_hmean(
|
||||
dataset_hit_num, dataset_hit_num, dataset_gt_num, dataset_pred_num)
|
||||
|
||||
dataset_results = {
|
||||
'num_gts': dataset_gt_num,
|
||||
'num_dets': dataset_pred_num,
|
||||
'num_match': dataset_hit_num,
|
||||
'recall': dataset_r,
|
||||
'precision': dataset_p,
|
||||
'hmean': dataset_h
|
||||
}
|
||||
|
||||
return dataset_results, img_results
|
|
@ -0,0 +1,64 @@
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn import uniform_init, xavier_init
|
||||
|
||||
from mmdet.models.builder import BACKBONES
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class VeryDeepVgg(nn.Module):
|
||||
|
||||
def __init__(self, leakyRelu=True, input_channels=3):
|
||||
super().__init__()
|
||||
|
||||
ks = [3, 3, 3, 3, 3, 3, 2]
|
||||
ps = [1, 1, 1, 1, 1, 1, 0]
|
||||
ss = [1, 1, 1, 1, 1, 1, 1]
|
||||
nm = [64, 128, 256, 256, 512, 512, 512]
|
||||
|
||||
self.channels = nm
|
||||
|
||||
cnn = nn.Sequential()
|
||||
|
||||
def convRelu(i, batchNormalization=False):
|
||||
nIn = input_channels if i == 0 else nm[i - 1]
|
||||
nOut = nm[i]
|
||||
cnn.add_module('conv{0}'.format(i),
|
||||
nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
|
||||
if batchNormalization:
|
||||
cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
|
||||
if leakyRelu:
|
||||
cnn.add_module('relu{0}'.format(i),
|
||||
nn.LeakyReLU(0.2, inplace=True))
|
||||
else:
|
||||
cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
|
||||
|
||||
convRelu(0)
|
||||
cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64
|
||||
convRelu(1)
|
||||
cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32
|
||||
convRelu(2, True)
|
||||
convRelu(3)
|
||||
cnn.add_module('pooling{0}'.format(2),
|
||||
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16
|
||||
convRelu(4, True)
|
||||
convRelu(5)
|
||||
cnn.add_module('pooling{0}'.format(3),
|
||||
nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16
|
||||
convRelu(6, True) # 512x1x16
|
||||
|
||||
self.cnn = cnn
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
xavier_init(m)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
uniform_init(m)
|
||||
|
||||
def out_channels(self):
|
||||
return self.channels[-1]
|
||||
|
||||
def forward(self, x):
|
||||
output = self.cnn(x)
|
||||
|
||||
return output
|
|
@ -0,0 +1,138 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmocr.models.builder import DECODERS
|
||||
from mmocr.models.textrecog.layers import (DotProductAttentionLayer,
|
||||
PositionAwareLayer)
|
||||
from .base_decoder import BaseDecoder
|
||||
|
||||
|
||||
@DECODERS.register_module()
|
||||
class PositionAttentionDecoder(BaseDecoder):
|
||||
|
||||
def __init__(self,
|
||||
num_classes=None,
|
||||
rnn_layers=2,
|
||||
dim_input=512,
|
||||
dim_model=128,
|
||||
max_seq_len=40,
|
||||
mask=True,
|
||||
return_feature=False,
|
||||
encode_value=False):
|
||||
super().__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.dim_input = dim_input
|
||||
self.dim_model = dim_model
|
||||
self.max_seq_len = max_seq_len
|
||||
self.return_feature = return_feature
|
||||
self.encode_value = encode_value
|
||||
self.mask = mask
|
||||
|
||||
self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model)
|
||||
|
||||
self.position_aware_module = PositionAwareLayer(
|
||||
self.dim_model, rnn_layers)
|
||||
|
||||
self.attention_layer = DotProductAttentionLayer()
|
||||
|
||||
self.prediction = None
|
||||
if not self.return_feature:
|
||||
pred_num_classes = num_classes - 1
|
||||
self.prediction = nn.Linear(
|
||||
dim_model if encode_value else dim_input, pred_num_classes)
|
||||
|
||||
def init_weights(self):
|
||||
pass
|
||||
|
||||
def _get_position_index(self, length, batch_size, device=None):
|
||||
position_index = torch.arange(0, length, device=device)
|
||||
position_index = position_index.repeat([batch_size, 1])
|
||||
position_index = position_index.long()
|
||||
return position_index
|
||||
|
||||
def forward_train(self, feat, out_enc, targets_dict, img_metas):
|
||||
valid_ratios = [
|
||||
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||
] if self.mask else None
|
||||
|
||||
targets = targets_dict['padded_targets'].to(feat.device)
|
||||
|
||||
#
|
||||
n, c_enc, h, w = out_enc.size()
|
||||
assert c_enc == self.dim_model
|
||||
_, c_feat, _, _ = feat.size()
|
||||
assert c_feat == self.dim_input
|
||||
_, len_q = targets.size()
|
||||
assert len_q <= self.max_seq_len
|
||||
|
||||
position_index = self._get_position_index(len_q, n, feat.device)
|
||||
|
||||
position_out_enc = self.position_aware_module(out_enc)
|
||||
|
||||
query = self.embedding(position_index)
|
||||
query = query.permute(0, 2, 1).contiguous()
|
||||
key = position_out_enc.view(n, c_enc, h * w)
|
||||
if self.encode_value:
|
||||
value = out_enc.view(n, c_enc, h * w)
|
||||
else:
|
||||
value = feat.view(n, c_feat, h * w)
|
||||
|
||||
mask = None
|
||||
if valid_ratios is not None:
|
||||
mask = query.new_zeros((n, h, w))
|
||||
for i, valid_ratio in enumerate(valid_ratios):
|
||||
valid_width = min(w, math.ceil(w * valid_ratio))
|
||||
mask[i, :, valid_width:] = 1
|
||||
mask = mask.bool()
|
||||
mask = mask.view(n, h * w)
|
||||
|
||||
attn_out = self.attention_layer(query, key, value, mask)
|
||||
attn_out = attn_out.permute(0, 2, 1).contiguous() # [n, len_q, dim_v]
|
||||
|
||||
if self.return_feature:
|
||||
return attn_out
|
||||
|
||||
return self.prediction(attn_out)
|
||||
|
||||
def forward_test(self, feat, out_enc, img_metas):
|
||||
valid_ratios = [
|
||||
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||
] if self.mask else None
|
||||
|
||||
seq_len = self.max_seq_len
|
||||
n, c_enc, h, w = out_enc.size()
|
||||
assert c_enc == self.dim_model
|
||||
_, c_feat, _, _ = feat.size()
|
||||
assert c_feat == self.dim_input
|
||||
|
||||
position_index = self._get_position_index(seq_len, n, feat.device)
|
||||
|
||||
position_out_enc = self.position_aware_module(out_enc)
|
||||
|
||||
query = self.embedding(position_index)
|
||||
query = query.permute(0, 2, 1).contiguous()
|
||||
key = position_out_enc.view(n, c_enc, h * w)
|
||||
if self.encode_value:
|
||||
value = out_enc.view(n, c_enc, h * w)
|
||||
else:
|
||||
value = feat.view(n, c_feat, h * w)
|
||||
|
||||
mask = None
|
||||
if valid_ratios is not None:
|
||||
mask = query.new_zeros((n, h, w))
|
||||
for i, valid_ratio in enumerate(valid_ratios):
|
||||
valid_width = min(w, math.ceil(w * valid_ratio))
|
||||
mask[i, :, valid_width:] = 1
|
||||
mask = mask.bool()
|
||||
mask = mask.view(n, h * w)
|
||||
|
||||
attn_out = self.attention_layer(query, key, value, mask)
|
||||
attn_out = attn_out.permute(0, 2, 1).contiguous()
|
||||
|
||||
if self.return_feature:
|
||||
return attn_out
|
||||
|
||||
return self.prediction(attn_out)
|
|
@ -0,0 +1,107 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmocr.models.builder import DECODERS, build_decoder
|
||||
from mmocr.models.textrecog.layers import RobustScannerFusionLayer
|
||||
from .base_decoder import BaseDecoder
|
||||
|
||||
|
||||
@DECODERS.register_module()
|
||||
class RobustScannerDecoder(BaseDecoder):
|
||||
|
||||
def __init__(self,
|
||||
num_classes=None,
|
||||
dim_input=512,
|
||||
dim_model=128,
|
||||
max_seq_len=40,
|
||||
start_idx=0,
|
||||
mask=True,
|
||||
padding_idx=None,
|
||||
encode_value=False,
|
||||
hybrid_decoder=None,
|
||||
position_decoder=None):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.dim_input = dim_input
|
||||
self.dim_model = dim_model
|
||||
self.max_seq_len = max_seq_len
|
||||
self.encode_value = encode_value
|
||||
self.start_idx = start_idx
|
||||
self.padding_idx = padding_idx
|
||||
self.mask = mask
|
||||
|
||||
# init hybrid decoder
|
||||
hybrid_decoder.update(num_classes=self.num_classes)
|
||||
hybrid_decoder.update(dim_input=self.dim_input)
|
||||
hybrid_decoder.update(dim_model=self.dim_model)
|
||||
hybrid_decoder.update(start_idx=self.start_idx)
|
||||
hybrid_decoder.update(padding_idx=self.padding_idx)
|
||||
hybrid_decoder.update(max_seq_len=self.max_seq_len)
|
||||
hybrid_decoder.update(mask=self.mask)
|
||||
hybrid_decoder.update(encode_value=self.encode_value)
|
||||
hybrid_decoder.update(return_feature=True)
|
||||
|
||||
self.hybrid_decoder = build_decoder(hybrid_decoder)
|
||||
|
||||
# init position decoder
|
||||
position_decoder.update(num_classes=self.num_classes)
|
||||
position_decoder.update(dim_input=self.dim_input)
|
||||
position_decoder.update(dim_model=self.dim_model)
|
||||
position_decoder.update(max_seq_len=self.max_seq_len)
|
||||
position_decoder.update(mask=self.mask)
|
||||
position_decoder.update(encode_value=self.encode_value)
|
||||
position_decoder.update(return_feature=True)
|
||||
|
||||
self.position_decoder = build_decoder(position_decoder)
|
||||
|
||||
self.fusion_module = RobustScannerFusionLayer(
|
||||
self.dim_model if encode_value else dim_input)
|
||||
|
||||
pred_num_classes = num_classes - 1
|
||||
self.prediction = nn.Linear(dim_model if encode_value else dim_input,
|
||||
pred_num_classes)
|
||||
|
||||
def init_weights(self):
|
||||
pass
|
||||
|
||||
def forward_train(self, feat, out_enc, targets_dict, img_metas):
|
||||
hybrid_glimpse = self.hybrid_decoder.forward_train(
|
||||
feat, out_enc, targets_dict, img_metas)
|
||||
position_glimpse = self.position_decoder.forward_train(
|
||||
feat, out_enc, targets_dict, img_metas)
|
||||
|
||||
fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse)
|
||||
|
||||
out = self.prediction(fusion_out)
|
||||
|
||||
return out
|
||||
|
||||
def forward_test(self, feat, out_enc, img_metas):
|
||||
seq_len = self.max_seq_len
|
||||
batch_size = feat.size(0)
|
||||
|
||||
decode_sequence = (feat.new_ones(
|
||||
(batch_size, seq_len)) * self.start_idx).long()
|
||||
|
||||
position_glimpse = self.position_decoder.forward_test(
|
||||
feat, out_enc, img_metas)
|
||||
|
||||
outputs = []
|
||||
for i in range(seq_len):
|
||||
hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
|
||||
feat, out_enc, decode_sequence, i, img_metas)
|
||||
|
||||
fusion_out = self.fusion_module(hybrid_glimpse_step,
|
||||
position_glimpse[:, i, :])
|
||||
|
||||
char_out = self.prediction(fusion_out)
|
||||
char_out = F.softmax(char_out, -1)
|
||||
outputs.append(char_out)
|
||||
_, max_idx = torch.max(char_out, dim=1, keepdim=False)
|
||||
if i < seq_len - 1:
|
||||
decode_sequence[:, i + 1] = max_idx
|
||||
|
||||
outputs = torch.stack(outputs, 1)
|
||||
|
||||
return outputs
|
|
@ -0,0 +1,165 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmocr.models.builder import DECODERS
|
||||
from mmocr.models.textrecog.layers import DotProductAttentionLayer
|
||||
from .base_decoder import BaseDecoder
|
||||
|
||||
|
||||
@DECODERS.register_module()
|
||||
class SequenceAttentionDecoder(BaseDecoder):
|
||||
|
||||
def __init__(self,
|
||||
num_classes=None,
|
||||
rnn_layers=2,
|
||||
dim_input=512,
|
||||
dim_model=128,
|
||||
max_seq_len=40,
|
||||
start_idx=0,
|
||||
mask=True,
|
||||
padding_idx=None,
|
||||
dropout_ratio=0,
|
||||
return_feature=False,
|
||||
encode_value=False):
|
||||
super().__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.dim_input = dim_input
|
||||
self.dim_model = dim_model
|
||||
self.return_feature = return_feature
|
||||
self.encode_value = encode_value
|
||||
self.max_seq_len = max_seq_len
|
||||
self.start_idx = start_idx
|
||||
self.mask = mask
|
||||
|
||||
self.embedding = nn.Embedding(
|
||||
self.num_classes, self.dim_model, padding_idx=padding_idx)
|
||||
|
||||
self.sequence_layer = nn.LSTM(
|
||||
input_size=dim_model,
|
||||
hidden_size=dim_model,
|
||||
num_layers=rnn_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout_ratio)
|
||||
|
||||
self.attention_layer = DotProductAttentionLayer()
|
||||
|
||||
self.prediction = None
|
||||
if not self.return_feature:
|
||||
pred_num_classes = num_classes - 1
|
||||
self.prediction = nn.Linear(
|
||||
dim_model if encode_value else dim_input, pred_num_classes)
|
||||
|
||||
def init_weights(self):
|
||||
pass
|
||||
|
||||
def forward_train(self, feat, out_enc, targets_dict, img_metas):
|
||||
valid_ratios = [
|
||||
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||
] if self.mask else None
|
||||
|
||||
targets = targets_dict['padded_targets'].to(feat.device)
|
||||
tgt_embedding = self.embedding(targets)
|
||||
|
||||
n, c_enc, h, w = out_enc.size()
|
||||
assert c_enc == self.dim_model
|
||||
_, c_feat, _, _ = feat.size()
|
||||
assert c_feat == self.dim_input
|
||||
_, len_q, c_q = tgt_embedding.size()
|
||||
assert c_q == self.dim_model
|
||||
assert len_q <= self.max_seq_len
|
||||
|
||||
query, _ = self.sequence_layer(tgt_embedding)
|
||||
query = query.permute(0, 2, 1).contiguous()
|
||||
key = out_enc.view(n, c_enc, h * w)
|
||||
if self.encode_value:
|
||||
value = key
|
||||
else:
|
||||
value = feat.view(n, c_feat, h * w)
|
||||
|
||||
mask = None
|
||||
if valid_ratios is not None:
|
||||
mask = query.new_zeros((n, h, w))
|
||||
for i, valid_ratio in enumerate(valid_ratios):
|
||||
valid_width = min(w, math.ceil(w * valid_ratio))
|
||||
mask[i, :, valid_width:] = 1
|
||||
mask = mask.bool()
|
||||
mask = mask.view(n, h * w)
|
||||
|
||||
attn_out = self.attention_layer(query, key, value, mask)
|
||||
attn_out = attn_out.permute(0, 2, 1).contiguous()
|
||||
|
||||
if self.return_feature:
|
||||
return attn_out
|
||||
|
||||
out = self.prediction(attn_out)
|
||||
|
||||
return out
|
||||
|
||||
def forward_test(self, feat, out_enc, img_metas):
|
||||
seq_len = self.max_seq_len
|
||||
batch_size = feat.size(0)
|
||||
|
||||
decode_sequence = (feat.new_ones(
|
||||
(batch_size, seq_len)) * self.start_idx).long()
|
||||
|
||||
outputs = []
|
||||
for i in range(seq_len):
|
||||
step_out = self.forward_test_step(feat, out_enc, decode_sequence,
|
||||
i, img_metas)
|
||||
outputs.append(step_out)
|
||||
_, max_idx = torch.max(step_out, dim=1, keepdim=False)
|
||||
if i < seq_len - 1:
|
||||
decode_sequence[:, i + 1] = max_idx
|
||||
|
||||
outputs = torch.stack(outputs, 1)
|
||||
|
||||
return outputs
|
||||
|
||||
def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
|
||||
img_metas):
|
||||
valid_ratios = [
|
||||
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||
] if self.mask else None
|
||||
|
||||
embed = self.embedding(decode_sequence)
|
||||
|
||||
n, c_enc, h, w = out_enc.size()
|
||||
assert c_enc == self.dim_model
|
||||
_, c_feat, _, _ = feat.size()
|
||||
assert c_feat == self.dim_input
|
||||
_, _, c_q = embed.size()
|
||||
assert c_q == self.dim_model
|
||||
|
||||
query, _ = self.sequence_layer(embed)
|
||||
query = query.permute(0, 2, 1).contiguous()
|
||||
key = out_enc.view(n, c_enc, h * w)
|
||||
if self.encode_value:
|
||||
value = key
|
||||
else:
|
||||
value = feat.view(n, c_feat, h * w)
|
||||
|
||||
mask = None
|
||||
if valid_ratios is not None:
|
||||
mask = query.new_zeros((n, h, w))
|
||||
for i, valid_ratio in enumerate(valid_ratios):
|
||||
valid_width = min(w, math.ceil(w * valid_ratio))
|
||||
mask[i, :, valid_width:] = 1
|
||||
mask = mask.bool()
|
||||
mask = mask.view(n, h * w)
|
||||
|
||||
# [n, c, l]
|
||||
attn_out = self.attention_layer(query, key, value, mask)
|
||||
|
||||
out = attn_out[:, :, current_step]
|
||||
|
||||
if self.return_feature:
|
||||
return out
|
||||
|
||||
out = self.prediction(out)
|
||||
out = F.softmax(out, dim=-1)
|
||||
|
||||
return out
|
|
@ -0,0 +1,23 @@
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn import xavier_init
|
||||
|
||||
from mmocr.models.builder import ENCODERS
|
||||
from .base_encoder import BaseEncoder
|
||||
|
||||
|
||||
@ENCODERS.register_module()
|
||||
class ChannelReductionEncoder(BaseEncoder):
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
|
||||
self.layer = nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
xavier_init(m)
|
||||
|
||||
def forward(self, feat, img_metas=None):
|
||||
return self.layer(feat)
|
|
@ -0,0 +1,27 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class DotProductAttentionLayer(nn.Module):
|
||||
|
||||
def __init__(self, dim_model=None):
|
||||
super().__init__()
|
||||
|
||||
self.scale = dim_model**-0.5 if dim_model is not None else 1.
|
||||
|
||||
def forward(self, query, key, value, mask=None):
|
||||
n, seq_len = mask.size()
|
||||
logits = torch.matmul(query.permute(0, 2, 1), key) * self.scale
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.view(n, 1, seq_len)
|
||||
logits = logits.masked_fill(mask, float('-inf'))
|
||||
|
||||
weights = F.softmax(logits, dim=2)
|
||||
|
||||
glimpse = torch.matmul(weights, value.transpose(1, 2))
|
||||
|
||||
glimpse = glimpse.permute(0, 2, 1).contiguous()
|
||||
|
||||
return glimpse
|
|
@ -0,0 +1,35 @@
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
class PositionAwareLayer(nn.Module):
|
||||
|
||||
def __init__(self, dim_model, rnn_layers=2):
|
||||
super().__init__()
|
||||
|
||||
self.dim_model = dim_model
|
||||
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=dim_model,
|
||||
hidden_size=dim_model,
|
||||
num_layers=rnn_layers,
|
||||
batch_first=True)
|
||||
|
||||
self.mixer = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
dim_model, dim_model, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(
|
||||
dim_model, dim_model, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
def forward(self, img_feature):
|
||||
n, c, h, w = img_feature.size()
|
||||
|
||||
rnn_input = img_feature.permute(0, 2, 3, 1).contiguous()
|
||||
rnn_input = rnn_input.view(n * h, w, c)
|
||||
rnn_output, _ = self.rnn(rnn_input)
|
||||
rnn_output = rnn_output.view(n, h, w, c)
|
||||
rnn_output = rnn_output.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
out = self.mixer(rnn_output)
|
||||
|
||||
return out
|
|
@ -0,0 +1,22 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class RobustScannerFusionLayer(nn.Module):
|
||||
|
||||
def __init__(self, dim_model, dim=-1):
|
||||
super().__init__()
|
||||
|
||||
self.dim_model = dim_model
|
||||
self.dim = dim
|
||||
|
||||
self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2)
|
||||
self.glu_layer = nn.GLU(dim=dim)
|
||||
|
||||
def forward(self, x0, x1):
|
||||
assert x0.size() == x1.size()
|
||||
fusion_input = torch.cat([x0, x1], self.dim)
|
||||
output = self.linear_layer(fusion_input)
|
||||
output = self.glu_layer(output)
|
||||
|
||||
return output
|
|
@ -0,0 +1,10 @@
|
|||
from mmdet.models.builder import DETECTORS
|
||||
from .encode_decode_recognizer import EncodeDecodeRecognizer
|
||||
|
||||
|
||||
@DETECTORS.register_module()
|
||||
class RobustScanner(EncodeDecodeRecognizer):
|
||||
"""Implementation of `RobustScanner.
|
||||
|
||||
<https://arxiv.org/pdf/2007.07542.pdf>
|
||||
"""
|
Binary file not shown.
Before Width: | Height: | Size: 11 KiB |
Binary file not shown.
After Width: | Height: | Size: 31 KiB |
Loading…
Reference in New Issue