mirror of https://github.com/open-mmlab/mmocr.git
commit 4ecd0cea8a (parent af78ffb407): add sar, seg and other components

Changed paths:
configs/_base_/recog_datasets
configs/_base_/recog_models
mmocr/core/evaluation
mmocr/models/textrecog/backbones
mmocr/models/textrecog/convertors
tests/test_utils/test_text
tools
@@ -0,0 +1,96 @@
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

gt_label_convertor = dict(
    type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomPaddingOCR',
        max_ratio=[0.15, 0.2, 0.15, 0.2],
        box_type='char_quads'),
    dict(type='OpencvToPil'),
    dict(
        type='RandomRotateImageBox',
        min_angle=-17,
        max_angle=17,
        box_type='char_quads'),
    dict(type='PilToOpencv'),
    dict(
        type='ResizeOCR',
        height=64,
        min_width=64,
        max_width=512,
        keep_aspect_ratio=True),
    dict(
        type='OCRSegTargets',
        label_convertor=gt_label_convertor,
        box_type='char_quads'),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(type='ToTensorOCR'),
    dict(type='FancyPCA'),
    dict(type='NormalizeOCR', **img_norm_cfg),
    dict(
        type='CustomFormatBundle',
        keys=['gt_kernels'],
        visualize=dict(flag=False, boundary_key=None),
        call_super=False),
    dict(
        type='Collect',
        keys=['img', 'gt_kernels'],
        meta_keys=['filename', 'ori_shape', 'img_shape'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeOCR',
        height=64,
        min_width=64,
        max_width=None,
        keep_aspect_ratio=True),
    dict(type='ToTensorOCR'),
    dict(type='NormalizeOCR', **img_norm_cfg),
    dict(type='CustomFormatBundle', call_super=False),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=['filename', 'ori_shape', 'img_shape'])
]

prefix = 'tests/data/ocr_char_ann_toy_dataset/'
train = dict(
    type='OCRSegDataset',
    img_prefix=prefix + 'imgs',
    ann_file=prefix + 'instances_train.txt',
    loader=dict(
        type='HardDiskLoader',
        repeat=100,
        parser=dict(
            type='LineJsonParser', keys=['file_name', 'annotations', 'text'])),
    pipeline=train_pipeline,
    test_mode=True)

test = dict(
    type='OCRDataset',
    img_prefix=prefix + 'imgs',
    ann_file=prefix + 'instances_test.txt',
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=test_pipeline,
    test_mode=True)

data = dict(
    samples_per_gpu=8,
    workers_per_gpu=1,
    train=dict(type='ConcatDataset', datasets=[train]),
    val=dict(type='ConcatDataset', datasets=[test]),
    test=dict(type='ConcatDataset', datasets=[test]))

evaluation = dict(interval=1, metric='acc')
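For reference, a minimal sketch of one `instances_train.txt` line that the `LineJsonParser` above (keys `file_name`, `annotations`, `text`) would accept. The field values are made up, but the `char_text`/`char_box` keys match what `OCRSegDataset._parse_anno_info` asserts later in this commit:

```python
import json

# Hypothetical one-line annotation for the char-level seg dataset: one JSON
# object per line, holding per-character quad boxes and the full text.
line = json.dumps({
    'file_name': 'img1.jpg',
    'annotations': [
        {'char_text': 'h', 'char_box': [10, 10, 30, 10, 30, 40, 10, 40]},
        {'char_text': 'i', 'char_box': [35, 10, 55, 10, 55, 40, 35, 40]},
    ],
    'text': 'hi',
})
info = json.loads(line)
assert set(info) == {'file_name', 'annotations', 'text'}
```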
@@ -0,0 +1,99 @@
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeOCR',
        height=32,
        min_width=32,
        max_width=160,
        keep_aspect_ratio=True),
    dict(type='ToTensorOCR'),
    dict(type='NormalizeOCR', **img_norm_cfg),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=[
            'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
        ]),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiRotateAugOCR',
        rotate_degrees=[0, 90, 270],
        transforms=[
            dict(
                type='ResizeOCR',
                height=32,
                min_width=32,
                max_width=160,
                keep_aspect_ratio=True),
            dict(type='ToTensorOCR'),
            dict(type='NormalizeOCR', **img_norm_cfg),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=[
                    'filename', 'ori_shape', 'img_shape', 'valid_ratio'
                ]),
        ])
]

dataset_type = 'OCRDataset'
img_prefix = 'tests/data/ocr_toy_dataset/imgs'
train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt'
train1 = dict(
    type=dataset_type,
    img_prefix=img_prefix,
    ann_file=train_anno_file1,
    loader=dict(
        type='HardDiskLoader',
        repeat=100,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=train_pipeline,
    test_mode=False)

train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
train2 = dict(
    type=dataset_type,
    img_prefix=img_prefix,
    ann_file=train_anno_file2,
    loader=dict(
        type='LmdbLoader',
        repeat=100,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=train_pipeline,
    test_mode=False)

test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
test = dict(
    type=dataset_type,
    img_prefix=img_prefix,
    ann_file=test_anno_file1,
    loader=dict(
        type='LmdbLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=test_pipeline,
    test_mode=True)

data = dict(
    samples_per_gpu=16,
    workers_per_gpu=2,
    train=dict(type='ConcatDataset', datasets=[train1, train2]),
    val=dict(type='ConcatDataset', datasets=[test]),
    test=dict(type='ConcatDataset', datasets=[test]))

evaluation = dict(interval=1, metric='acc')
@@ -0,0 +1,11 @@
label_convertor = dict(
    type='CTCConvertor', dict_type='DICT90', with_unknown=False)

model = dict(
    type='CRNNNet',
    preprocessor=None,
    backbone=dict(type='VeryDeepVgg', leakyRelu=False),
    encoder=None,
    decoder=dict(type='CRNNDecoder', in_channels=512, rnn_flag=True),
    loss=dict(type='CTCLoss', flatten=False),
    label_convertor=label_convertor)
@@ -0,0 +1,24 @@
label_convertor = dict(
    type='AttnConvertor', dict_type='DICT90', with_unknown=True)

hybrid_decoder = dict(type='SequenceAttentionDecoder')

position_decoder = dict(type='PositionAttentionDecoder')

model = dict(
    type='RobustScanner',
    backbone=dict(type='ResNet31OCR'),
    encoder=dict(
        type='ChannelReductionEncoder',
        in_channels=512,
        out_channels=128,
    ),
    decoder=dict(
        type='RobustScannerDecoder',
        dim_input=512,
        dim_model=128,
        hybrid_decoder=hybrid_decoder,
        position_decoder=position_decoder),
    loss=dict(type='SARLoss'),
    label_convertor=label_convertor,
    max_seq_len=30)
@@ -0,0 +1,24 @@
label_convertor = dict(
    type='AttnConvertor', dict_type='DICT90', with_unknown=True)

model = dict(
    type='SARNet',
    backbone=dict(type='ResNet31OCR'),
    encoder=dict(
        type='SAREncoder',
        enc_bi_rnn=False,
        enc_do_rnn=0.1,
        enc_gru=False,
    ),
    decoder=dict(
        type='ParallelSARDecoder',
        enc_bi_rnn=False,
        dec_bi_rnn=False,
        dec_do_rnn=0,
        dec_gru=False,
        pred_dropout=0.1,
        d_k=512,
        pred_concat=True),
    loss=dict(type='SARLoss'),
    label_convertor=label_convertor,
    max_seq_len=30)
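As a quick sanity check, such a `_base_` model file can be loaded standalone with mmcv's config loader. A sketch, assuming the repo root as the working directory:

```python
from mmcv import Config

# Load the base model config on its own and inspect the registered types.
cfg = Config.fromfile('configs/_base_/recog_models/sar.py')
print(cfg.model.type)          # 'SARNet'
print(cfg.model.decoder.type)  # 'ParallelSARDecoder'
```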
@@ -0,0 +1,11 @@
label_convertor = dict(
    type='AttnConvertor', dict_type='DICT90', with_unknown=False)

model = dict(
    type='TransformerNet',
    backbone=dict(type='ResNet31OCR'),
    encoder=dict(type='TFEncoder'),
    decoder=dict(type='TFDecoder'),
    loss=dict(type='TFLoss'),
    label_convertor=label_convertor,
    max_seq_len=40)
@@ -0,0 +1,65 @@
# Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition

## Introduction

[ALGORITHM]

```
@inproceedings{li2019show,
  title={Show, attend and read: A simple and strong baseline for irregular text recognition},
  author={Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={33},
  number={01},
  pages={8610--8617},
  year={2019}
}
```

## Dataset

### Train Dataset

| trainset   | instance_num | repeat_num | source                   |
| :--------: | :----------: | :--------: | :----------------------: |
| icdar_2011 | 3567         | 20         | real                     |
| icdar_2013 | 848          | 20         | real                     |
| icdar_2015 | 4468         | 20         | real                     |
| coco_text  | 42142        | 20         | real                     |
| IIIT5K     | 2000         | 20         | real                     |
| SynthText  | 2400000      | 1          | synth                    |
| SynthAdd   | 1216889      | 1          | synth, 1.6m in [[1]](#1) |
| Syn90k     | 2400000      | 1          | synth                    |

### Test Dataset

| testset | instance_num | type                        |
| :-----: | :----------: | :-------------------------: |
| IIIT5K  | 3000         | regular                     |
| SVT     | 647          | regular                     |
| IC13    | 1015         | regular                     |
| IC15    | 2077         | irregular                   |
| SVTP    | 645          | irregular, 639 in [[1]](#1) |
| CT80    | 288          | irregular                   |

## Results and Models

IIIT5K, SVT and IC13 are regular-text benchmarks; IC15, SVTP and CT80 are irregular.

| Methods | Backbone | Decoder | IIIT5K | SVT | IC13 | IC15 | SVTP | CT80 | download |
| :-----: | :------: | :-----: | :----: | :-: | :--: | :--: | :--: | :--: | :------: |
| [SAR](/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py) | R31-1/8-1/4 | ParallelSARDecoder | 95.0 | 89.6 | 93.7 | 79.0 | 82.2 | 88.9 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth) \| [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_parallel_decoder_academic.py) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210327_154129.log.json) |
| [SAR](/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py) | R31-1/8-1/4 | SequentialSARDecoder | 95.2 | 88.7 | 92.4 | 78.2 | 81.9 | 89.6 | [model](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_sequential_decoder_academic-d06c9a8e.pth) \| [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_sequential_decoder_academic.py) \| [log](https://download.openmmlab.com/mmocr/textrecog/sar/20210330_105728.log.json) |

**Notes:**

- `R31-1/8-1/4` means the output feature map of the backbone is 1/8 of the input image in height and 1/4 in width.
- We did not use beam search during decoding.
- We implemented two kinds of decoders, `ParallelSARDecoder` and `SequentialSARDecoder`:
  - `ParallelSARDecoder`: decodes all steps in parallel during training with an `LSTM` layer; it is faster.
  - `SequentialSARDecoder`: decodes step by step during training with `LSTMCell`; it is easier to understand.
- For the training dataset:
  - We did not construct the 20 distinct data groups of [[1]](#1) and train the model group-by-group, since that would make training overly complicated.
  - Instead, we randomly selected `2.4m` patches from `Syn90k`, `2.4m` from `SynthText` and `1.2m` from `SynthAdd`, and grouped all the data together. See the [config](https://download.openmmlab.com/mmocr/textrecog/sar/sar_r31_academic.py) for details. A sketch of the resulting pool size follows this hunk.
- We used 48 GPUs with `total_batch_size = 64 * 48` in the experiments above to speed up training, while keeping the `initial lr = 1e-3` unchanged.

## References

<a id="1">[1]</a> Li, Hui and Wang, Peng and Shen, Chunhua and Zhang, Guyu. Show, attend and read: A simple and strong baseline for irregular text recognition. In AAAI 2019.
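A small sketch of how the `instance_num`/`repeat_num` columns combine into the effective per-epoch training pool; the numbers are taken straight from the Train Dataset table above:

```python
# (instance_num, repeat_num) per trainset, from the Train Dataset table.
train_sets = {
    'icdar_2011': (3567, 20),
    'icdar_2013': (848, 20),
    'icdar_2015': (4468, 20),
    'coco_text': (42142, 20),
    'IIIT5K': (2000, 20),
    'SynthText': (2400000, 1),
    'SynthAdd': (1216889, 1),
    'Syn90k': (2400000, 1),
}
effective = sum(n * r for n, r in train_sets.values())
print(effective)  # 7077389 samples per epoch after repetition
```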
@@ -0,0 +1,219 @@
_base_ = ['../../_base_/default_runtime.py']

label_convertor = dict(
    type='AttnConvertor', dict_type='DICT90', with_unknown=True)

model = dict(
    type='SARNet',
    backbone=dict(type='ResNet31OCR'),
    encoder=dict(
        type='SAREncoder',
        enc_bi_rnn=False,
        enc_do_rnn=0.1,
        enc_gru=False,
    ),
    decoder=dict(
        type='ParallelSARDecoder',
        enc_bi_rnn=False,
        dec_bi_rnn=False,
        dec_do_rnn=0,
        dec_gru=False,
        pred_dropout=0.1,
        d_k=512,
        pred_concat=True),
    loss=dict(type='SARLoss'),
    label_convertor=label_convertor,
    max_seq_len=30)

# optimizer
optimizer = dict(type='Adam', lr=1e-3)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5

img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeOCR',
        height=48,
        min_width=48,
        max_width=160,
        keep_aspect_ratio=True,
        width_downsample_ratio=0.25),
    dict(type='ToTensorOCR'),
    dict(type='NormalizeOCR', **img_norm_cfg),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=[
            'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
        ]),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiRotateAugOCR',
        rotate_degrees=[0, 90, 270],
        transforms=[
            dict(
                type='ResizeOCR',
                height=48,
                min_width=48,
                max_width=160,
                keep_aspect_ratio=True,
                width_downsample_ratio=0.25),
            dict(type='ToTensorOCR'),
            dict(type='NormalizeOCR', **img_norm_cfg),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=[
                    'filename', 'ori_shape', 'img_shape', 'valid_ratio'
                ]),
        ])
]

dataset_type = 'OCRDataset'

train_prefix = 'data/mixture/'

train_img_prefix1 = train_prefix + 'icdar_2011'
train_img_prefix2 = train_prefix + 'icdar_2013'
train_img_prefix3 = train_prefix + 'icdar_2015'
train_img_prefix4 = train_prefix + 'coco_text'
train_img_prefix5 = train_prefix + 'IIIT5K'
train_img_prefix6 = train_prefix + 'SynthText_Add'
train_img_prefix7 = train_prefix + 'SynthText'
train_img_prefix8 = train_prefix + 'Syn90k'

train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt'
train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt'
train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt'
train_ann_file4 = train_prefix + 'coco_text/train_label.txt'
train_ann_file5 = train_prefix + 'IIIT5K/train_label.txt'
train_ann_file6 = train_prefix + 'SynthText_Add/label.txt'
train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt'
train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt'

train1 = dict(
    type=dataset_type,
    img_prefix=train_img_prefix1,
    ann_file=train_ann_file1,
    loader=dict(
        type='HardDiskLoader',
        repeat=20,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=train_pipeline,
    test_mode=False)

train2 = {key: value for key, value in train1.items()}
train2['img_prefix'] = train_img_prefix2
train2['ann_file'] = train_ann_file2

train3 = {key: value for key, value in train1.items()}
train3['img_prefix'] = train_img_prefix3
train3['ann_file'] = train_ann_file3

train4 = {key: value for key, value in train1.items()}
train4['img_prefix'] = train_img_prefix4
train4['ann_file'] = train_ann_file4

train5 = {key: value for key, value in train1.items()}
train5['img_prefix'] = train_img_prefix5
train5['ann_file'] = train_ann_file5

train6 = dict(
    type=dataset_type,
    img_prefix=train_img_prefix6,
    ann_file=train_ann_file6,
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=train_pipeline,
    test_mode=False)

train7 = {key: value for key, value in train6.items()}
train7['img_prefix'] = train_img_prefix7
train7['ann_file'] = train_ann_file7

train8 = {key: value for key, value in train6.items()}
train8['img_prefix'] = train_img_prefix8
train8['ann_file'] = train_ann_file8

test_prefix = 'data/mixture/'
test_img_prefix1 = test_prefix + 'IIIT5K/'
test_img_prefix2 = test_prefix + 'svt/'
test_img_prefix3 = test_prefix + 'icdar_2013/'
test_img_prefix4 = test_prefix + 'icdar_2015/'
test_img_prefix5 = test_prefix + 'svtp/'
test_img_prefix6 = test_prefix + 'ct80/'

test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file2 = test_prefix + 'svt/test_label.txt'
test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
test_ann_file5 = test_prefix + 'svtp/test_label.txt'
test_ann_file6 = test_prefix + 'ct80/test_label.txt'

test1 = dict(
    type=dataset_type,
    img_prefix=test_img_prefix1,
    ann_file=test_ann_file1,
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=test_pipeline,
    test_mode=True)

test2 = {key: value for key, value in test1.items()}
test2['img_prefix'] = test_img_prefix2
test2['ann_file'] = test_ann_file2

test3 = {key: value for key, value in test1.items()}
test3['img_prefix'] = test_img_prefix3
test3['ann_file'] = test_ann_file3

test4 = {key: value for key, value in test1.items()}
test4['img_prefix'] = test_img_prefix4
test4['ann_file'] = test_ann_file4

test5 = {key: value for key, value in test1.items()}
test5['img_prefix'] = test_img_prefix5
test5['ann_file'] = test_ann_file5

test6 = {key: value for key, value in test1.items()}
test6['img_prefix'] = test_img_prefix6
test6['ann_file'] = test_ann_file6

data = dict(
    samples_per_gpu=64,
    workers_per_gpu=2,
    train=dict(
        type='ConcatDataset',
        datasets=[
            train1, train2, train3, train4, train5, train6, train7, train8
        ]),
    val=dict(
        type='ConcatDataset',
        datasets=[test1, test2, test3, test4, test5, test6]),
    test=dict(
        type='ConcatDataset',
        datasets=[test1, test2, test3, test4, test5, test6]))

evaluation = dict(interval=1, metric='acc')
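One detail of the `train2 = {key: value for key, value in train1.items()}` pattern used above is worth noting: it is a shallow copy, so the nested `loader` dict stays shared between the variants. The config only reassigns top-level keys (`img_prefix`, `ann_file`), which is safe. A minimal sketch:

```python
# Shallow-copy semantics of the dict comprehension used in the config.
train1 = {'img_prefix': 'a', 'loader': {'repeat': 20}}
train2 = {key: value for key, value in train1.items()}  # same as dict(train1)
train2['img_prefix'] = 'b'

assert train1['img_prefix'] == 'a'           # top-level keys are independent
assert train2['loader'] is train1['loader']  # the nested dict is shared
```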
@@ -0,0 +1,110 @@
_base_ = [
    '../../_base_/default_runtime.py', '../../_base_/recog_models/sar.py'
]

# optimizer
optimizer = dict(type='Adam', lr=1e-3)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5

img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeOCR',
        height=48,
        min_width=48,
        max_width=160,
        keep_aspect_ratio=True),
    dict(type='ToTensorOCR'),
    dict(type='NormalizeOCR', **img_norm_cfg),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=[
            'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
        ]),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiRotateAugOCR',
        rotate_degrees=[0, 90, 270],
        transforms=[
            dict(
                type='ResizeOCR',
                height=48,
                min_width=48,
                max_width=160,
                keep_aspect_ratio=True),
            dict(type='ToTensorOCR'),
            dict(type='NormalizeOCR', **img_norm_cfg),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=[
                    'filename', 'ori_shape', 'img_shape', 'valid_ratio'
                ]),
        ])
]

dataset_type = 'OCRDataset'
img_prefix = 'tests/data/ocr_toy_dataset/imgs'
train_anno_file1 = 'tests/data/ocr_toy_dataset/label.txt'
train1 = dict(
    type=dataset_type,
    img_prefix=img_prefix,
    ann_file=train_anno_file1,
    loader=dict(
        type='HardDiskLoader',
        repeat=100,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=train_pipeline,
    test_mode=False)

train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
train2 = dict(
    type=dataset_type,
    img_prefix=img_prefix,
    ann_file=train_anno_file2,
    loader=dict(
        type='LmdbLoader',
        repeat=100,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=train_pipeline,
    test_mode=False)

test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
test = dict(
    type=dataset_type,
    img_prefix=img_prefix,
    ann_file=test_anno_file1,
    loader=dict(
        type='LmdbLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=test_pipeline,
    test_mode=True)

data = dict(
    samples_per_gpu=16,
    workers_per_gpu=2,
    train=dict(type='ConcatDataset', datasets=[train1, train2]),
    val=dict(type='ConcatDataset', datasets=[test]),
    test=dict(type='ConcatDataset', datasets=[test]))

evaluation = dict(interval=1, metric='acc')
@@ -0,0 +1,219 @@
_base_ = ['../../_base_/default_runtime.py']

label_convertor = dict(
    type='AttnConvertor', dict_type='DICT90', with_unknown=True)

model = dict(
    type='SARNet',
    backbone=dict(type='ResNet31OCR'),
    encoder=dict(
        type='SAREncoder',
        enc_bi_rnn=False,
        enc_do_rnn=0.1,
        enc_gru=False,
    ),
    decoder=dict(
        type='SequentialSARDecoder',
        enc_bi_rnn=False,
        dec_bi_rnn=False,
        dec_do_rnn=0,
        dec_gru=False,
        pred_dropout=0.1,
        d_k=512,
        pred_concat=True),
    loss=dict(type='SARLoss'),
    label_convertor=label_convertor,
    max_seq_len=30)

# optimizer
optimizer = dict(type='Adam', lr=1e-3)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5

img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeOCR',
        height=48,
        min_width=48,
        max_width=160,
        keep_aspect_ratio=True,
        width_downsample_ratio=0.25),
    dict(type='ToTensorOCR'),
    dict(type='NormalizeOCR', **img_norm_cfg),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=[
            'filename', 'ori_shape', 'img_shape', 'text', 'valid_ratio'
        ]),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiRotateAugOCR',
        rotate_degrees=[0, 90, 270],
        transforms=[
            dict(
                type='ResizeOCR',
                height=48,
                min_width=48,
                max_width=160,
                keep_aspect_ratio=True,
                width_downsample_ratio=0.25),
            dict(type='ToTensorOCR'),
            dict(type='NormalizeOCR', **img_norm_cfg),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=[
                    'filename', 'ori_shape', 'img_shape', 'valid_ratio'
                ]),
        ])
]

dataset_type = 'OCRDataset'

train_prefix = 'data/mixture/'

train_img_prefix1 = train_prefix + 'icdar_2011'
train_img_prefix2 = train_prefix + 'icdar_2013'
train_img_prefix3 = train_prefix + 'icdar_2015'
train_img_prefix4 = train_prefix + 'coco_text'
train_img_prefix5 = train_prefix + 'IIIT5K'
train_img_prefix6 = train_prefix + 'SynthText_Add'
train_img_prefix7 = train_prefix + 'SynthText'
train_img_prefix8 = train_prefix + 'Syn90k'

train_ann_file1 = train_prefix + 'icdar_2011/train_label.txt'
train_ann_file2 = train_prefix + 'icdar_2013/train_label.txt'
train_ann_file3 = train_prefix + 'icdar_2015/train_label.txt'
train_ann_file4 = train_prefix + 'coco_text/train_label.txt'
train_ann_file5 = train_prefix + 'IIIT5K/train_label.txt'
train_ann_file6 = train_prefix + 'SynthText_Add/label.txt'
train_ann_file7 = train_prefix + 'SynthText/shuffle_labels.txt'
train_ann_file8 = train_prefix + 'Syn90k/shuffle_labels.txt'

train1 = dict(
    type=dataset_type,
    img_prefix=train_img_prefix1,
    ann_file=train_ann_file1,
    loader=dict(
        type='HardDiskLoader',
        repeat=20,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=train_pipeline,
    test_mode=False)

train2 = {key: value for key, value in train1.items()}
train2['img_prefix'] = train_img_prefix2
train2['ann_file'] = train_ann_file2

train3 = {key: value for key, value in train1.items()}
train3['img_prefix'] = train_img_prefix3
train3['ann_file'] = train_ann_file3

train4 = {key: value for key, value in train1.items()}
train4['img_prefix'] = train_img_prefix4
train4['ann_file'] = train_ann_file4

train5 = {key: value for key, value in train1.items()}
train5['img_prefix'] = train_img_prefix5
train5['ann_file'] = train_ann_file5

train6 = dict(
    type=dataset_type,
    img_prefix=train_img_prefix6,
    ann_file=train_ann_file6,
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=train_pipeline,
    test_mode=False)

train7 = {key: value for key, value in train6.items()}
train7['img_prefix'] = train_img_prefix7
train7['ann_file'] = train_ann_file7

train8 = {key: value for key, value in train6.items()}
train8['img_prefix'] = train_img_prefix8
train8['ann_file'] = train_ann_file8

test_prefix = 'data/mixture/'
test_img_prefix1 = test_prefix + 'IIIT5K/'
test_img_prefix2 = test_prefix + 'svt/'
test_img_prefix3 = test_prefix + 'icdar_2013/'
test_img_prefix4 = test_prefix + 'icdar_2015/'
test_img_prefix5 = test_prefix + 'svtp/'
test_img_prefix6 = test_prefix + 'ct80/'

test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file2 = test_prefix + 'svt/test_label.txt'
test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file4 = test_prefix + 'icdar_2015/test_label.txt'
test_ann_file5 = test_prefix + 'svtp/test_label.txt'
test_ann_file6 = test_prefix + 'ct80/test_label.txt'

test1 = dict(
    type=dataset_type,
    img_prefix=test_img_prefix1,
    ann_file=test_ann_file1,
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=test_pipeline,
    test_mode=True)

test2 = {key: value for key, value in test1.items()}
test2['img_prefix'] = test_img_prefix2
test2['ann_file'] = test_ann_file2

test3 = {key: value for key, value in test1.items()}
test3['img_prefix'] = test_img_prefix3
test3['ann_file'] = test_ann_file3

test4 = {key: value for key, value in test1.items()}
test4['img_prefix'] = test_img_prefix4
test4['ann_file'] = test_ann_file4

test5 = {key: value for key, value in test1.items()}
test5['img_prefix'] = test_img_prefix5
test5['ann_file'] = test_ann_file5

test6 = {key: value for key, value in test1.items()}
test6['img_prefix'] = test_img_prefix6
test6['ann_file'] = test_ann_file6

data = dict(
    samples_per_gpu=64,
    workers_per_gpu=2,
    train=dict(
        type='ConcatDataset',
        datasets=[
            train1, train2, train3, train4, train5, train6, train7, train8
        ]),
    val=dict(
        type='ConcatDataset',
        datasets=[test1, test2, test3, test4, test5, test6]),
    test=dict(
        type='ConcatDataset',
        datasets=[test1, test2, test3, test4, test5, test6]))

evaluation = dict(interval=1, metric='acc')
@@ -0,0 +1,36 @@
# A Baseline for Segmentation-Based Text Recognition

## Introduction

A baseline method for segmentation-based text recognition.

[ALGORITHM]

## Dataset

### Train Dataset

| trainset  | instance_num | repeat_num | source |
| :-------: | :----------: | :--------: | :----: |
| SynthText | 7266686      | 1          | synth  |

### Test Dataset

| testset | instance_num | type      |
| :-----: | :----------: | :-------: |
| IIIT5K  | 3000         | regular   |
| SVT     | 647          | regular   |
| IC13    | 1015         | regular   |
| CT80    | 288          | irregular |

## Results and Models

| Backbone | Neck   | Head | IIIT5K | SVT  | IC13 | CT80 | base_lr | batch_size/gpu | gpus | download |
| :------: | :----: | :--: | :----: | :--: | :--: | :--: | :-----: | :------------: | :--: | :------: |
| R31-1/16 | FPNOCR | 1x   | 90.9   | 81.8 | 90.7 | 80.9 | 1e-4    | 16             | 4    | [model](https://download.openmmlab.com/mmocr/textrecog/seg/seg_r31_1by16_fpnocr_academic-0c50e163.pth) \| [config](https://download.openmmlab.com/mmocr/textrecog/seg/seg_r31_1by16_fpnocr_academic.py) \| [log](https://download.openmmlab.com/mmocr/textrecog/seg/20210325_112835.log.json) |

**Notes:**

- `R31-1/16` means the feature map from the backbone is 1/16 of the input image in both height and width.
- `1x` means the feature map from the head has the same height and width as the input image.
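Illustrative arithmetic for the two notes above, with an arbitrarily chosen input size:

```python
h, w = 64, 256                      # example input size (assumed)
backbone_feat = (h // 16, w // 16)  # `R31-1/16` -> (4, 16)
head_out = (h, w)                   # `1x`: head output matches the input size
```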
@@ -0,0 +1,160 @@
_base_ = ['../../_base_/default_runtime.py']

# optimizer
optimizer = dict(type='Adam', lr=1e-4)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5

label_convertor = dict(
    type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)

model = dict(
    type='SegRecognizer',
    backbone=dict(
        type='ResNet31OCR',
        layers=[1, 2, 5, 3],
        channels=[32, 64, 128, 256, 512, 512],
        out_indices=[0, 1, 2, 3],
        stage4_pool_cfg=dict(kernel_size=2, stride=2),
        last_stage_pool=True),
    neck=dict(
        type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256),
    head=dict(
        type='SegHead',
        in_channels=256,
        upsample_param=dict(scale_factor=2.0, mode='nearest')),
    loss=dict(
        type='SegLoss', seg_downsample_ratio=1.0, seg_with_loss_weight=True),
    label_convertor=label_convertor)

find_unused_parameters = True

img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

gt_label_convertor = dict(
    type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomPaddingOCR',
        max_ratio=[0.15, 0.2, 0.15, 0.2],
        box_type='char_quads'),
    dict(type='OpencvToPil'),
    dict(
        type='RandomRotateImageBox',
        min_angle=-17,
        max_angle=17,
        box_type='char_quads'),
    dict(type='PilToOpencv'),
    dict(
        type='ResizeOCR',
        height=64,
        min_width=64,
        max_width=512,
        keep_aspect_ratio=True),
    dict(
        type='OCRSegTargets',
        label_convertor=gt_label_convertor,
        box_type='char_quads'),
    dict(type='RandomRotateTextDet', rotate_ratio=0.5, max_angle=15),
    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
    dict(type='ToTensorOCR'),
    dict(type='FancyPCA'),
    dict(type='NormalizeOCR', **img_norm_cfg),
    dict(
        type='CustomFormatBundle',
        keys=['gt_kernels'],
        visualize=dict(flag=False, boundary_key=None),
        call_super=False),
    dict(
        type='Collect',
        keys=['img', 'gt_kernels'],
        meta_keys=['filename', 'ori_shape', 'img_shape'])
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeOCR',
        height=64,
        min_width=64,
        max_width=None,
        keep_aspect_ratio=True),
    dict(type='ToTensorOCR'),
    dict(type='NormalizeOCR', **img_norm_cfg),
    dict(type='CustomFormatBundle', call_super=False),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=['filename', 'ori_shape', 'img_shape'])
]

train_img_root = 'data/mixture/'

train_img_prefix = train_img_root + 'SynthText'

train_ann_file = train_img_root + 'SynthText/instances_train.txt'

train = dict(
    type='OCRSegDataset',
    img_prefix=train_img_prefix,
    ann_file=train_ann_file,
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineJsonParser', keys=['file_name', 'annotations', 'text'])),
    pipeline=train_pipeline,
    test_mode=False)

dataset_type = 'OCRDataset'
test_prefix = 'data/mixture/'

test_img_prefix1 = test_prefix + 'IIIT5K/'
test_img_prefix2 = test_prefix + 'svt/'
test_img_prefix3 = test_prefix + 'icdar_2013/'
test_img_prefix4 = test_prefix + 'ct80/'

test_ann_file1 = test_prefix + 'IIIT5K/test_label.txt'
test_ann_file2 = test_prefix + 'svt/test_label.txt'
test_ann_file3 = test_prefix + 'icdar_2013/test_label_1015.txt'
test_ann_file4 = test_prefix + 'ct80/test_label.txt'

test1 = dict(
    type=dataset_type,
    img_prefix=test_img_prefix1,
    ann_file=test_ann_file1,
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=test_pipeline,
    test_mode=True)

test2 = {key: value for key, value in test1.items()}
test2['img_prefix'] = test_img_prefix2
test2['ann_file'] = test_ann_file2

test3 = {key: value for key, value in test1.items()}
test3['img_prefix'] = test_img_prefix3
test3['ann_file'] = test_ann_file3

test4 = {key: value for key, value in test1.items()}
test4['img_prefix'] = test_img_prefix4
test4['ann_file'] = test_ann_file4

data = dict(
    samples_per_gpu=16,
    workers_per_gpu=2,
    train=dict(type='ConcatDataset', datasets=[train]),
    val=dict(type='ConcatDataset', datasets=[test1, test2, test3, test4]),
    test=dict(type='ConcatDataset', datasets=[test1, test2, test3, test4]))

evaluation = dict(interval=1, metric='acc')
@@ -0,0 +1,35 @@
_base_ = [
    '../../_base_/default_runtime.py',
    '../../_base_/recog_datasets/seg_toy_dataset.py'
]

# optimizer
optimizer = dict(type='Adam', lr=1e-4)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5

label_convertor = dict(
    type='SegConvertor', dict_type='DICT36', with_unknown=True, lower=True)

model = dict(
    type='SegRecognizer',
    backbone=dict(
        type='ResNet31OCR',
        layers=[1, 2, 5, 3],
        channels=[32, 64, 128, 256, 512, 512],
        out_indices=[0, 1, 2, 3],
        stage4_pool_cfg=dict(kernel_size=2, stride=2),
        last_stage_pool=True),
    neck=dict(
        type='FPNOCR', in_channels=[128, 256, 512, 512], out_channels=256),
    head=dict(
        type='SegHead',
        in_channels=256,
        upsample_param=dict(scale_factor=2.0, mode='nearest')),
    loss=dict(
        type='SegLoss', seg_downsample_ratio=1.0, seg_with_loss_weight=False),
    label_convertor=label_convertor)

find_unused_parameters = True
@@ -0,0 +1,30 @@
## Introduction

### Train Dataset

| trainset   | instance_num | repeat_num | note  |
| :--------: | :----------: | :--------: | :---: |
| icdar_2011 | 3567         | 20         | real  |
| icdar_2013 | 848          | 20         | real  |
| icdar_2015 | 4468         | 20         | real  |
| coco_text  | 42142        | 20         | real  |
| IIIT5K     | 2000         | 20         | real  |
| SynthText  | 2400000      | 1          | synth |

### Test Dataset

| testset | instance_num | note                        |
| :-----: | :----------: | :-------------------------: |
| IIIT5K  | 3000         | regular                     |
| SVT     | 647          | regular                     |
| IC13    | 1015         | regular                     |
| IC15    | 2077         | irregular                   |
| SVTP    | 645          | irregular, 639 in [[1]](#1) |
| CT80    | 288          | irregular                   |

## Results and Models

IIIT5K, SVT and IC13 are regular-text benchmarks; IC15, SVTP and CT80 are irregular.

| methods     | IIIT5K | SVT  | IC13 | IC15 | SVTP | CT80 | download |
| :---------: | :----: | :--: | :--: | :--: | :--: | :--: | :------: |
| Transformer | 93.3   | 85.8 | 91.3 | 73.2 | 76.6 | 87.8 |          |
@@ -0,0 +1,12 @@
_base_ = [
    '../../_base_/default_runtime.py',
    '../../_base_/recog_models/transformer.py',
    '../../_base_/recog_datasets/toy_dataset.py'
]

# optimizer
optimizer = dict(type='Adadelta', lr=1)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[3, 4])
total_epochs = 5
@@ -0,0 +1,3 @@
from .version import __version__, short_version

__all__ = ['__version__', 'short_version']
@@ -0,0 +1,133 @@
import re
from difflib import SequenceMatcher

import Levenshtein


def cal_true_positive_char(pred, gt):
    """Calculate the number of correct characters in a prediction.

    Args:
        pred (str): Prediction text.
        gt (str): Ground truth text.

    Returns:
        true_positive_char_num (int): The true positive number.
    """

    all_opt = SequenceMatcher(None, pred, gt)
    true_positive_char_num = 0
    for opt, _, _, s2, e2 in all_opt.get_opcodes():
        if opt == 'equal':
            true_positive_char_num += (e2 - s2)
    return true_positive_char_num


def count_matches(pred_texts, gt_texts):
    """Count the various match numbers for metric calculation.

    Args:
        pred_texts (list[str]): Predicted text strings.
        gt_texts (list[str]): Ground truth text strings.

    Returns:
        match_res (dict[str: int]): Match numbers used for
            metric calculation.
    """
    match_res = {
        'gt_char_num': 0,
        'pred_char_num': 0,
        'true_positive_char_num': 0,
        'gt_word_num': 0,
        'match_word_num': 0,
        'match_word_ignore_case': 0,
        'match_word_ignore_case_symbol': 0
    }
    # Keep only letters, digits and CJK characters when ignoring symbols.
    comp = re.compile('[^A-Za-z0-9\u4e00-\u9fa5]')
    norm_ed_sum = 0.0
    for pred_text, gt_text in zip(pred_texts, gt_texts):
        if gt_text == pred_text:
            match_res['match_word_num'] += 1
        gt_text_lower = gt_text.lower()
        pred_text_lower = pred_text.lower()
        if gt_text_lower == pred_text_lower:
            match_res['match_word_ignore_case'] += 1
        gt_text_lower_ignore = comp.sub('', gt_text_lower)
        pred_text_lower_ignore = comp.sub('', pred_text_lower)
        if gt_text_lower_ignore == pred_text_lower_ignore:
            match_res['match_word_ignore_case_symbol'] += 1
        match_res['gt_word_num'] += 1

        # normalized edit distance
        edit_dist = Levenshtein.distance(pred_text_lower_ignore,
                                         gt_text_lower_ignore)
        norm_ed = float(edit_dist) / max(1, len(gt_text_lower_ignore),
                                         len(pred_text_lower_ignore))
        norm_ed_sum += norm_ed

        # numbers used for char-level recall & precision
        match_res['gt_char_num'] += len(gt_text_lower_ignore)
        match_res['pred_char_num'] += len(pred_text_lower_ignore)
        true_positive_char_num = cal_true_positive_char(
            pred_text_lower_ignore, gt_text_lower_ignore)
        match_res['true_positive_char_num'] += true_positive_char_num

    normalized_edit_distance = norm_ed_sum / max(1, len(gt_texts))
    match_res['ned'] = normalized_edit_distance

    return match_res


def eval_ocr_metric(pred_texts, gt_texts):
    """Evaluate the text recognition performance with the metrics word
    accuracy and 1-N.E.D. See https://rrc.cvc.uab.es/?ch=14&com=tasks for
    details.

    Args:
        pred_texts (list[str]): Text strings of prediction.
        gt_texts (list[str]): Text strings of ground truth.

    Returns:
        eval_res (dict[str: float]): Metric dict for text recognition, which
            includes:
            - word_acc: Accuracy at word level.
            - word_acc_ignore_case: Accuracy at word level, ignoring letter
              case.
            - word_acc_ignore_case_symbol: Accuracy at word level, ignoring
              letter case and symbols (the default metric for academic
              evaluation).
            - char_recall: Recall at character level, ignoring letter case
              and symbols.
            - char_precision: Precision at character level, ignoring letter
              case and symbols.
            - 1-N.E.D: 1 - normalized_edit_distance.
    """
    assert isinstance(pred_texts, list)
    assert isinstance(gt_texts, list)
    assert len(pred_texts) == len(gt_texts)

    match_res = count_matches(pred_texts, gt_texts)
    eps = 1e-8
    char_recall = 1.0 * match_res['true_positive_char_num'] / (
        eps + match_res['gt_char_num'])
    char_precision = 1.0 * match_res['true_positive_char_num'] / (
        eps + match_res['pred_char_num'])
    word_acc = 1.0 * match_res['match_word_num'] / (
        eps + match_res['gt_word_num'])
    word_acc_ignore_case = 1.0 * match_res['match_word_ignore_case'] / (
        eps + match_res['gt_word_num'])
    word_acc_ignore_case_symbol = 1.0 * match_res[
        'match_word_ignore_case_symbol'] / (
            eps + match_res['gt_word_num'])

    eval_res = {}
    eval_res['word_acc'] = word_acc
    eval_res['word_acc_ignore_case'] = word_acc_ignore_case
    eval_res['word_acc_ignore_case_symbol'] = word_acc_ignore_case_symbol
    eval_res['char_recall'] = char_recall
    eval_res['char_precision'] = char_precision
    eval_res['1-N.E.D'] = 1.0 - match_res['ned']

    for key, value in eval_res.items():
        eval_res[key] = float('{:.4f}'.format(value))

    return eval_res
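A quick usage sketch of `eval_ocr_metric` with made-up strings, showing how case and symbols affect the three word-level accuracies:

```python
pred_texts = ['hello', 'world!']
gt_texts = ['Hello', 'world']

res = eval_ocr_metric(pred_texts, gt_texts)
print(res['word_acc'])                     # 0.0  (no exact match)
print(res['word_acc_ignore_case'])         # 0.5  ('hello' == 'Hello'.lower())
print(res['word_acc_ignore_case_symbol'])  # 1.0  ('!' is stripped as a symbol)
```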
@@ -0,0 +1,15 @@
from mmdet.datasets.builder import DATASETS, build_dataloader, build_dataset
from .base_dataset import BaseDataset
from .icdar_dataset import IcdarDataset
from .kie_dataset import KIEDataset
from .ocr_dataset import OCRDataset
from .ocr_seg_dataset import OCRSegDataset
from .pipelines import CustomFormatBundle, DBNetTargets, DRRGTargets
from .text_det_dataset import TextDetDataset
from .utils import *  # noqa: F401,F403

__all__ = [
    'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
    'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
    'DBNetTargets', 'OCRSegDataset', 'DRRGTargets', 'KIEDataset'
]
@@ -0,0 +1,166 @@
import numpy as np
from mmcv.utils import print_log
from torch.utils.data import Dataset

from mmdet.datasets.builder import DATASETS
from mmdet.datasets.pipelines import Compose
from mmocr.datasets.builder import build_loader


@DATASETS.register_module()
class BaseDataset(Dataset):
    """Custom dataset for text detection, text recognition, and their
    downstream tasks.

    1. The text detection annotation format is as follows.
       The `annotations` field is optional for testing. (This is one line of
       the annotation file, with its line-json-str converted to a dict for
       readability only.)

        {
            "file_name": "sample.jpg",
            "height": 1080,
            "width": 960,
            "annotations":
                [
                    {
                        "iscrowd": 0,
                        "category_id": 1,
                        "bbox": [357.0, 667.0, 804.0, 100.0],
                        "segmentation": [[361, 667, 710, 670,
                                          72, 767, 357, 763]]
                    }
                ]
        }

    2. The two text recognition annotation formats are as follows.
       The `x1,y1,x2,y2,x3,y3,x4,y4` fields are used for online crop
       augmentation during training. See the parsing sketch after this hunk.

        format1: sample.jpg hello
        format2: sample.jpg 20 20 100 20 100 40 20 40 hello

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline.
        loader (dict): Dictionary to construct a loader
            that loads annotation infos.
        img_prefix (str, optional): Image prefix used to generate the full
            image path.
        test_mode (bool, optional): If set True, the try...except fallback
            in __getitem__ is turned off.
    """

    def __init__(self,
                 ann_file,
                 loader,
                 pipeline,
                 img_prefix='',
                 test_mode=False):
        super().__init__()
        self.test_mode = test_mode
        self.img_prefix = img_prefix
        self.ann_file = ann_file
        # load annotations
        loader.update(ann_file=ann_file)
        self.data_infos = build_loader(loader)
        # processing pipeline
        self.pipeline = Compose(pipeline)
        # Set the group flag and classes; both are placeholders with no
        # meaning for text detection and recognition.
        self._set_group_flag()
        self.CLASSES = 0

    def __len__(self):
        return len(self.data_infos)

    def _set_group_flag(self):
        """Set flag."""
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def pre_pipeline(self, results):
        """Prepare the results dict for the pipeline."""
        results['img_prefix'] = self.img_prefix

    def prepare_train_img(self, index):
        """Get training data and annotations from the pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after the pipeline, with new
                keys introduced by the pipeline.
        """
        img_info = self.data_infos[index]
        results = dict(img_info=img_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, index):
        """Get testing data from the pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Testing data after the pipeline, with new keys introduced
                by the pipeline.
        """
        return self.prepare_train_img(index)

    def _log_error_index(self, index):
        """Log the data info of a bad index."""
        try:
            data_info = self.data_infos[index]
            img_prefix = self.img_prefix
            print_log(f'Warning: skip broken file {data_info} '
                      f'with img_prefix {img_prefix}')
        except Exception as e:
            print_log(f'load index {index} with error {e}')

    def _get_next_index(self, index):
        """Get the next index from the dataset."""
        self._log_error_index(index)
        index = (index + 1) % len(self)
        return index

    def __getitem__(self, index):
        """Get training/test data from the pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training/test data.
        """
        if self.test_mode:
            return self.prepare_test_img(index)

        while True:
            try:
                data = self.prepare_train_img(index)
                if data is None:
                    raise Exception('prepared train data empty')
                break
            except Exception as e:
                print_log(f'prepare index {index} with error {e}')
                index = self._get_next_index(index)
        return data

    def format_results(self, results, **kwargs):
        """Placeholder to format results into dataset-specific output."""
        pass

    def evaluate(self, results, metric=None, logger=None, **kwargs):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict[str: float]
        """
        raise NotImplementedError
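To make the two recognition formats in the docstring concrete, here is roughly what a `LineStrParser` configured with `keys=['filename', 'text']`, `keys_idx=[0, 1]`, `separator=' '` recovers from a format1 line. This is a sketch of the mapping, not the parser's actual implementation:

```python
line = 'sample.jpg hello'  # recognition format1
parts = line.split(' ')
info = {'filename': parts[0], 'text': parts[1]}
assert info == {'filename': 'sample.jpg', 'text': 'hello'}
```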
@@ -0,0 +1,14 @@
from mmcv.utils import Registry, build_from_cfg

LOADERS = Registry('loader')
PARSERS = Registry('parser')


def build_loader(cfg):
    """Build an annotation file loader."""
    return build_from_cfg(cfg, LOADERS)


def build_parser(cfg):
    """Build an annotation file parser."""
    return build_from_cfg(cfg, PARSERS)
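The `loader` dicts in the dataset configs are resolved through these registries: `BaseDataset.__init__` injects `ann_file` into the dict and calls `build_loader`, and the loader in turn builds its nested `parser` dict via `build_parser`. A sketch of the config shape being resolved (the commented call requires the loader/parser classes to be registered first):

```python
loader_cfg = dict(
    type='HardDiskLoader',                            # registered in LOADERS
    ann_file='tests/data/ocr_toy_dataset/label.txt',  # injected by BaseDataset
    repeat=1,
    parser=dict(
        type='LineStrParser',                         # registered in PARSERS
        keys=['filename', 'text'],
        keys_idx=[0, 1],
        separator=' '))
# data_infos = build_loader(loader_cfg)
```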
@@ -0,0 +1,34 @@
from mmdet.datasets.builder import DATASETS
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
from mmocr.datasets.base_dataset import BaseDataset


@DATASETS.register_module()
class OCRDataset(BaseDataset):

    def pre_pipeline(self, results):
        results['img_prefix'] = self.img_prefix
        results['text'] = results['img_info']['text']

    def evaluate(self, results, metric='acc', logger=None, **kwargs):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict[str: float]
        """
        gt_texts = []
        pred_texts = []
        for i in range(len(self)):
            item_info = self.data_infos[i]
            text = item_info['text']
            gt_texts.append(text)
            pred_texts.append(results[i]['text'])

        eval_results = eval_ocr_metric(pred_texts, gt_texts)

        return eval_results
@@ -0,0 +1,90 @@
import mmocr.utils as utils
from mmdet.datasets.builder import DATASETS
from mmocr.datasets.ocr_dataset import OCRDataset


@DATASETS.register_module()
class OCRSegDataset(OCRDataset):

    def pre_pipeline(self, results):
        results['img_prefix'] = self.img_prefix

    def _parse_anno_info(self, annotations):
        """Parse char box annotations.

        Args:
            annotations (list[dict]): Annotations of one image, where
                each dict is for one character.

        Returns:
            dict: A dict containing the following keys:

                - chars (list[str]): List of character strings.
                - char_rects (list[list[float]]): List of char boxes, each in
                  rectangle style: [x_min, y_min, x_max, y_max].
                - char_quads (list[list[float]]): List of char boxes, each in
                  quadrangle style: [x1, y1, x2, y2, x3, y3, x4, y4].
        """

        assert utils.is_type_list(annotations, dict)
        assert 'char_box' in annotations[0]
        assert 'char_text' in annotations[0]
        assert len(annotations[0]['char_box']) == 4 or \
            len(annotations[0]['char_box']) == 8

        chars, char_rects, char_quads = [], [], []
        for ann in annotations:
            char_box = ann['char_box']
            if len(char_box) == 4:
                char_box_type = ann.get('char_box_type', 'xyxy')
                if char_box_type == 'xyxy':
                    char_rects.append(char_box)
                    char_quads.append([
                        char_box[0], char_box[1], char_box[2], char_box[1],
                        char_box[2], char_box[3], char_box[0], char_box[3]
                    ])
                elif char_box_type == 'xywh':
                    x1, y1, w, h = char_box
                    x2 = x1 + w
                    y2 = y1 + h
                    char_rects.append([x1, y1, x2, y2])
                    char_quads.append([x1, y1, x2, y1, x2, y2, x1, y2])
                else:
                    raise ValueError(f'invalid char_box_type {char_box_type}')
            elif len(char_box) == 8:
                x_list, y_list = [], []
                for i in range(4):
                    x_list.append(char_box[2 * i])
                    y_list.append(char_box[2 * i + 1])
                x_max, x_min = max(x_list), min(x_list)
                y_max, y_min = max(y_list), min(y_list)
                char_rects.append([x_min, y_min, x_max, y_max])
                char_quads.append(char_box)
            else:
                raise Exception(
                    f'invalid num in char box: {len(char_box)} not in (4, 8)')
            chars.append(ann['char_text'])

        ann = dict(chars=chars, char_rects=char_rects, char_quads=char_quads)

        return ann

    def prepare_train_img(self, index):
        """Get training data and annotations from the pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after the pipeline, with new
                keys introduced by the pipeline.
        """
        img_ann_info = self.data_infos[index]
        img_info = {
            'filename': img_ann_info['file_name'],
        }
        ann_info = self._parse_anno_info(img_ann_info['annotations'])
        results = dict(img_info=img_info, ann_info=ann_info)

        self.pre_pipeline(results)

        return self.pipeline(results)
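A worked example (hypothetical numbers) of the `xywh` branch in `_parse_anno_info`: a 4-value box becomes both a rect and a 4-point quad:

```python
char_box = [10, 20, 30, 15]  # x1, y1, w, h
x1, y1, w, h = char_box
x2, y2 = x1 + w, y1 + h
rect = [x1, y1, x2, y2]                  # char_rects entry
quad = [x1, y1, x2, y1, x2, y2, x1, y2]  # char_quads entry (clockwise)
assert rect == [10, 20, 40, 35]
assert quad == [10, 20, 40, 20, 40, 35, 10, 35]
```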
@ -0,0 +1,185 @@
import cv2
import numpy as np
from shapely.geometry import LineString, Point, Polygon

import mmocr.utils as utils


def sort_vertex(points_x, points_y):
    """Sort box vertices in clockwise order from left-top first.

    Args:
        points_x (list[float]): x of four vertices.
        points_y (list[float]): y of four vertices.

    Returns:
        sorted_points_x (list[float]): x of sorted four vertices.
        sorted_points_y (list[float]): y of sorted four vertices.
    """
    assert utils.is_type_list(points_x, float) or utils.is_type_list(
        points_x, int)
    assert utils.is_type_list(points_y, float) or utils.is_type_list(
        points_y, int)
    assert len(points_x) == 4
    assert len(points_y) == 4

    x = np.array(points_x)
    y = np.array(points_y)
    center_x = np.sum(x) * 0.25
    center_y = np.sum(y) * 0.25

    x_arr = np.array(x - center_x)
    y_arr = np.array(y - center_y)

    angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
    sort_idx = np.argsort(angle)

    sorted_points_x, sorted_points_y = [], []
    for i in range(4):
        sorted_points_x.append(points_x[sort_idx[i]])
        sorted_points_y.append(points_y[sort_idx[i]])

    return convert_canonical(sorted_points_x, sorted_points_y)


def convert_canonical(points_x, points_y):
    """Make the left-top vertex come first.

    Args:
        points_x (list[float]): x of four vertices.
        points_y (list[float]): y of four vertices.

    Returns:
        sorted_points_x (list[float]): x of sorted four vertices.
        sorted_points_y (list[float]): y of sorted four vertices.
    """
    assert utils.is_type_list(points_x, float) or utils.is_type_list(
        points_x, int)
    assert utils.is_type_list(points_y, float) or utils.is_type_list(
        points_y, int)
    assert len(points_x) == 4
    assert len(points_y) == 4

    points = [Point(points_x[i], points_y[i]) for i in range(4)]

    polygon = Polygon([(p.x, p.y) for p in points])
    min_x, min_y, _, _ = polygon.bounds
    points_to_lefttop = [
        LineString([points[i], Point(min_x, min_y)]) for i in range(4)
    ]
    distances = np.array([line.length for line in points_to_lefttop])
    sort_dist_idx = np.argsort(distances)
    lefttop_idx = sort_dist_idx[0]

    if lefttop_idx == 0:
        point_orders = [0, 1, 2, 3]
    elif lefttop_idx == 1:
        point_orders = [1, 2, 3, 0]
    elif lefttop_idx == 2:
        point_orders = [2, 3, 0, 1]
    else:
        point_orders = [3, 0, 1, 2]

    sorted_points_x = [points_x[i] for i in point_orders]
    sorted_points_y = [points_y[j] for j in point_orders]

    return sorted_points_x, sorted_points_y


def box_jitter(points_x, points_y, jitter_ratio_x=0.5, jitter_ratio_y=0.1):
    """Jitter on the coordinates of bounding box.

    Args:
        points_x (list[float | int]): List of x for four vertices.
        points_y (list[float | int]): List of y for four vertices.
        jitter_ratio_x (float): Horizontal jitter ratio relative to the height.
        jitter_ratio_y (float): Vertical jitter ratio relative to the height.
    """
    assert len(points_x) == 4
    assert len(points_y) == 4
    assert isinstance(jitter_ratio_x, float)
    assert isinstance(jitter_ratio_y, float)
    assert 0 <= jitter_ratio_x < 1
    assert 0 <= jitter_ratio_y < 1

    points = [Point(points_x[i], points_y[i]) for i in range(4)]
    line_list = [
        LineString([points[i], points[i + 1 if i < 3 else 0]])
        for i in range(4)
    ]

    tmp_h = max(line_list[1].length, line_list[3].length)

    for i in range(4):
        jitter_pixel_x = (np.random.rand() - 0.5) * 2 * jitter_ratio_x * tmp_h
        jitter_pixel_y = (np.random.rand() - 0.5) * 2 * jitter_ratio_y * tmp_h
        points_x[i] += jitter_pixel_x
        points_y[i] += jitter_pixel_y


def warp_img(src_img,
             box,
             jitter_flag=False,
             jitter_ratio_x=0.5,
             jitter_ratio_y=0.1):
    """Crop box area from image using opencv warpPerspective, with or without
    box jitter.

    Args:
        src_img (np.array): Image before cropping.
        box (list[float | int]): Coordinates of quadrangle.
    """
    assert utils.is_type_list(box, float) or utils.is_type_list(box, int)
    assert len(box) == 8

    h, w = src_img.shape[:2]
    points_x = [min(max(x, 0), w) for x in box[0:8:2]]
    points_y = [min(max(y, 0), h) for y in box[1:9:2]]

    points_x, points_y = sort_vertex(points_x, points_y)

    if jitter_flag:
        box_jitter(
            points_x,
            points_y,
            jitter_ratio_x=jitter_ratio_x,
            jitter_ratio_y=jitter_ratio_y)

    points = [Point(points_x[i], points_y[i]) for i in range(4)]
    edges = [
        LineString([points[i], points[i + 1 if i < 3 else 0]])
        for i in range(4)
    ]

    pts1 = np.float32([[points[i].x, points[i].y] for i in range(4)])
    box_width = max(edges[0].length, edges[2].length)
    box_height = max(edges[1].length, edges[3].length)

    pts2 = np.float32([[0, 0], [box_width, 0], [box_width, box_height],
                       [0, box_height]])
    M = cv2.getPerspectiveTransform(pts1, pts2)
    dst_img = cv2.warpPerspective(src_img, M,
                                  (int(box_width), int(box_height)))

    return dst_img


def crop_img(src_img, box):
    """Crop box area to rectangle.

    Args:
        src_img (np.array): Image before crop.
        box (list[float | int]): Points of quadrangle.
    """
    assert utils.is_type_list(box, float) or utils.is_type_list(box, int)
    assert len(box) == 8

    h, w = src_img.shape[:2]
    points_x = [min(max(x, 0), w) for x in box[0:8:2]]
    points_y = [min(max(y, 0), h) for y in box[1:9:2]]

    left = int(min(points_x))
    top = int(min(points_y))
    right = int(max(points_x))
    bottom = int(max(points_y))

    dst_img = src_img[top:bottom, left:right]

    return dst_img
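A minimal usage sketch of the two cropping helpers above; the image path and box coordinates are made up for illustration.

# warp_img rectifies the quadrangle with a perspective transform, while
# crop_img only takes the axis-aligned bounding rectangle.
img = cv2.imread('demo.jpg')  # hypothetical image
box = [100, 40, 220, 48, 216, 90, 96, 82]  # x1, y1, ..., x4, y4
patch = warp_img(img, box, jitter_flag=False)
rect = crop_img(img, box)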
@ -0,0 +1,201 @@
import cv2
import numpy as np

import mmocr.utils.check_argument as check_argument
from mmdet.core import BitmapMasks
from mmdet.datasets.builder import PIPELINES
from mmocr.models.builder import build_convertor


@PIPELINES.register_module()
class OCRSegTargets:
    """Generate gt shrunk kernels for segmentation based OCR framework.

    Args:
        label_convertor (dict): Dictionary to construct label_convertor
            to convert char to index.
        attn_shrink_ratio (float): The area shrink ratio
            between attention kernels and gt text masks.
        seg_shrink_ratio (float): The area shrink ratio
            between segmentation kernels and gt text masks.
        box_type (str): Character box type, should be either
            'char_rects' or 'char_quads', with 'char_rects'
            for rectangle with ``xyxy`` style and 'char_quads'
            for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
        pad_val (int): Value used to fill the padded region outside
            the resized image.
    """

    def __init__(self,
                 label_convertor=None,
                 attn_shrink_ratio=0.5,
                 seg_shrink_ratio=0.25,
                 box_type='char_rects',
                 pad_val=255):

        assert isinstance(attn_shrink_ratio, float)
        assert isinstance(seg_shrink_ratio, float)
        assert 0. < attn_shrink_ratio < 1.0
        assert 0. < seg_shrink_ratio < 1.0
        assert label_convertor is not None
        assert box_type in ('char_rects', 'char_quads')

        self.attn_shrink_ratio = attn_shrink_ratio
        self.seg_shrink_ratio = seg_shrink_ratio
        self.label_convertor = build_convertor(label_convertor)
        self.box_type = box_type
        self.pad_val = pad_val

    def shrink_char_quad(self, char_quad, shrink_ratio):
        """Shrink char box in style of quadrangle.

        Args:
            char_quad (list[float]): Char box with format
                [x1, y1, x2, y2, x3, y3, x4, y4].
            shrink_ratio (float): The area shrink ratio
                between gt kernels and gt text masks.
        """
        points = [[char_quad[0], char_quad[1]], [char_quad[2], char_quad[3]],
                  [char_quad[4], char_quad[5]], [char_quad[6], char_quad[7]]]
        shrink_points = []
        for p_idx, point in enumerate(points):
            p1 = points[(p_idx + 3) % 4]
            p2 = points[(p_idx + 1) % 4]

            dist1 = self.l2_dist_two_points(p1, point)
            dist2 = self.l2_dist_two_points(p2, point)
            min_dist = min(dist1, dist2)

            v1 = [p1[0] - point[0], p1[1] - point[1]]
            v2 = [p2[0] - point[0], p2[1] - point[1]]

            temp_dist1 = (shrink_ratio * min_dist /
                          dist1) if min_dist != 0 else 0.
            temp_dist2 = (shrink_ratio * min_dist /
                          dist2) if min_dist != 0 else 0.

            v1 = [temp * temp_dist1 for temp in v1]
            v2 = [temp * temp_dist2 for temp in v2]

            shrink_point = [
                round(point[0] + v1[0] + v2[0]),
                round(point[1] + v1[1] + v2[1])
            ]
            shrink_points.append(shrink_point)

        poly = np.array(shrink_points)

        return poly

    def shrink_char_rect(self, char_rect, shrink_ratio):
        """Shrink char box in style of rectangle.

        Args:
            char_rect (list[float]): Char box with format
                [x_min, y_min, x_max, y_max].
            shrink_ratio (float): The area shrink ratio
                between gt kernels and gt text masks.
        """
        x_min, y_min, x_max, y_max = char_rect
        w = x_max - x_min
        h = y_max - y_min
        x_min_s = round((x_min + x_max - w * shrink_ratio) / 2)
        y_min_s = round((y_min + y_max - h * shrink_ratio) / 2)
        x_max_s = round((x_min + x_max + w * shrink_ratio) / 2)
        y_max_s = round((y_min + y_max + h * shrink_ratio) / 2)
        poly = np.array([[x_min_s, y_min_s], [x_max_s, y_min_s],
                         [x_max_s, y_max_s], [x_min_s, y_max_s]])

        return poly

    def generate_kernels(self,
                         resize_shape,
                         pad_shape,
                         char_boxes,
                         char_inds,
                         shrink_ratio=0.5,
                         binary=True):
        """Generate char instance kernels for one shrink ratio.

        Args:
            resize_shape (tuple(int, int)): Image size (height, width)
                after resizing.
            pad_shape (tuple(int, int)): Image size (height, width)
                after padding.
            char_boxes (list[list[float]]): The list of char polygons.
            char_inds (list[int]): List of char indexes.
            shrink_ratio (float): The shrink ratio of kernel.
            binary (bool): If True, return binary ndarray
                containing 0 & 1 only.

        Returns:
            char_kernel (ndarray): The text kernel mask of (height, width).
        """
        assert isinstance(resize_shape, tuple)
        assert isinstance(pad_shape, tuple)
        assert check_argument.is_2dlist(char_boxes)
        assert check_argument.is_type_list(char_inds, int)
        assert isinstance(shrink_ratio, float)
        assert isinstance(binary, bool)

        char_kernel = np.zeros(pad_shape, dtype=np.int32)
        char_kernel[:resize_shape[0], resize_shape[1]:] = self.pad_val

        for i, char_box in enumerate(char_boxes):
            if self.box_type == 'char_rects':
                poly = self.shrink_char_rect(char_box, shrink_ratio)
            elif self.box_type == 'char_quads':
                poly = self.shrink_char_quad(char_box, shrink_ratio)

            fill_value = 1 if binary else char_inds[i]
            cv2.fillConvexPoly(char_kernel, poly.astype(np.int32), fill_value)

        return char_kernel

    def l2_dist_two_points(self, p1, p2):
        return ((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5

    def __call__(self, results):
        img_shape = results['img_shape']
        resize_shape = results['resize_shape']

        h_scale = 1.0 * resize_shape[0] / img_shape[0]
        w_scale = 1.0 * resize_shape[1] / img_shape[1]

        char_boxes, char_inds = [], []
        char_num = len(results['ann_info'][self.box_type])
        for i in range(char_num):
            char_box = results['ann_info'][self.box_type][i]
            num_points = 2 if self.box_type == 'char_rects' else 4
            for j in range(num_points):
                char_box[j * 2] = round(char_box[j * 2] * w_scale)
                char_box[j * 2 + 1] = round(char_box[j * 2 + 1] * h_scale)
            char_boxes.append(char_box)
            char = results['ann_info']['chars'][i]
            char_ind = self.label_convertor.str2idx([char])[0][0]
            char_inds.append(char_ind)

        resize_shape = tuple(results['resize_shape'][:2])
        pad_shape = tuple(results['pad_shape'][:2])
        binary_target = self.generate_kernels(
            resize_shape,
            pad_shape,
            char_boxes,
            char_inds,
            shrink_ratio=self.attn_shrink_ratio,
            binary=True)

        seg_target = self.generate_kernels(
            resize_shape,
            pad_shape,
            char_boxes,
            char_inds,
            shrink_ratio=self.seg_shrink_ratio,
            binary=False)

        mask = np.ones(pad_shape, dtype=np.int32)
        mask[:resize_shape[0], resize_shape[1]:] = 0

        results['gt_kernels'] = BitmapMasks([binary_target, seg_target, mask],
                                            pad_shape[0], pad_shape[1])
        results['mask_fields'] = ['gt_kernels']

        return results
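A toy sketch of generate_kernels under invented shapes; the label convertor config and every number below are assumptions for illustration only.

# Hypothetical single 'char_rects' box on a 64x160 padded canvas whose
# valid (resized) region is 64x128.
targets = OCRSegTargets(
    label_convertor=dict(
        type='SegConvertor', dict_type='DICT36', with_unknown=True,
        lower=True),
    box_type='char_rects')
kernel = targets.generate_kernels(
    resize_shape=(64, 128),
    pad_shape=(64, 160),
    char_boxes=[[4, 8, 60, 56]],
    char_inds=[3],
    shrink_ratio=0.5,
    binary=True)
# kernel is a (64, 160) int32 map: 1 inside the shrunk char box, 0 in the
# valid background, and pad_val (255) in the padded right margin.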
@ -0,0 +1,447 @@
import math

import cv2
import mmcv
import numpy as np
import torch
import torchvision.transforms.functional as TF
from mmcv.runner.dist_utils import get_dist_info
from PIL import Image
from shapely.geometry import Polygon
from shapely.geometry import box as shapely_box

import mmocr.utils as utils
from mmdet.datasets.builder import PIPELINES
from mmocr.datasets.pipelines.crop import warp_img


@PIPELINES.register_module()
class ResizeOCR:
    """Image resizing and padding for OCR.

    Args:
        height (int | tuple(int)): Image height after resizing.
        min_width (none | int | tuple(int)): Image minimum width
            after resizing.
        max_width (none | int | tuple(int)): Image maximum width
            after resizing.
        keep_aspect_ratio (bool): Keep image aspect ratio if True
            during resizing; otherwise resize to the size
            height * max_width.
        img_pad_value (int): Scalar to fill padding area.
        width_downsample_ratio (float): Downsample ratio in horizontal
            direction from input image to output feature.
    """

    def __init__(self,
                 height,
                 min_width=None,
                 max_width=None,
                 keep_aspect_ratio=True,
                 img_pad_value=0,
                 width_downsample_ratio=1.0 / 16):
        assert isinstance(height, (int, tuple))
        assert utils.is_none_or_type(min_width, (int, tuple))
        assert utils.is_none_or_type(max_width, (int, tuple))
        if not keep_aspect_ratio:
            assert max_width is not None, \
                '"max_width" must be assigned ' + \
                'if "keep_aspect_ratio" is False'
        assert isinstance(img_pad_value, int)
        if isinstance(height, tuple):
            assert isinstance(min_width, tuple)
            assert isinstance(max_width, tuple)
            assert len(height) == len(min_width) == len(max_width)

        self.height = height
        self.min_width = min_width
        self.max_width = max_width
        self.keep_aspect_ratio = keep_aspect_ratio
        self.img_pad_value = img_pad_value
        self.width_downsample_ratio = width_downsample_ratio

    def __call__(self, results):
        rank, _ = get_dist_info()
        if isinstance(self.height, int):
            dst_height = self.height
            dst_min_width = self.min_width
            dst_max_width = self.max_width
        else:
            # Multi-scale resize used in distributed training.
            # Choose one (height, width) pair for one rank id.
            idx = rank % len(self.height)
            dst_height = self.height[idx]
            dst_min_width = self.min_width[idx]
            dst_max_width = self.max_width[idx]

        img_shape = results['img_shape']
        ori_height, ori_width = img_shape[:2]
        valid_ratio = 1.0
        resize_shape = list(img_shape)
        pad_shape = list(img_shape)

        if self.keep_aspect_ratio:
            new_width = math.ceil(float(dst_height) / ori_height * ori_width)
            width_divisor = int(1 / self.width_downsample_ratio)
            # make sure new_width is an integral multiple of width_divisor.
            if new_width % width_divisor != 0:
                new_width = round(new_width / width_divisor) * width_divisor
            if dst_min_width is not None:
                new_width = max(dst_min_width, new_width)
            if dst_max_width is not None:
                valid_ratio = min(1.0, 1.0 * new_width / dst_max_width)
                resize_width = min(dst_max_width, new_width)
                img_resize = cv2.resize(results['img'],
                                        (resize_width, dst_height))
                resize_shape = img_resize.shape
                pad_shape = img_resize.shape
                if new_width < dst_max_width:
                    img_resize = mmcv.impad(
                        img_resize,
                        shape=(dst_height, dst_max_width),
                        pad_val=self.img_pad_value)
                    pad_shape = img_resize.shape
            else:
                img_resize = cv2.resize(results['img'],
                                        (new_width, dst_height))
                resize_shape = img_resize.shape
                pad_shape = img_resize.shape
        else:
            img_resize = cv2.resize(results['img'],
                                    (dst_max_width, dst_height))
            resize_shape = img_resize.shape
            pad_shape = img_resize.shape

        results['img'] = img_resize
        results['resize_shape'] = resize_shape
        results['pad_shape'] = pad_shape
        results['valid_ratio'] = valid_ratio

        return results

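A worked sketch of the width arithmetic above with made-up numbers, replayed outside the class:

import math

ori_height, ori_width = 32, 64           # hypothetical narrow crop
dst_height, dst_max_width = 32, 160
width_divisor = 16                       # int(1 / width_downsample_ratio)

new_width = math.ceil(dst_height / ori_height * ori_width)   # 64
if new_width % width_divisor != 0:
    new_width = round(new_width / width_divisor) * width_divisor
valid_ratio = min(1.0, new_width / dst_max_width)             # 0.4
# the image is resized to (32, 64) and right-padded to (32, 160);
# valid_ratio records how much of the padded width is real content.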
@PIPELINES.register_module()
class ToTensorOCR:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor."""

    def __init__(self):
        pass

    def __call__(self, results):
        results['img'] = TF.to_tensor(results['img'].copy())

        return results

@PIPELINES.register_module()
class NormalizeOCR:
    """Normalize a tensor image with mean and standard deviation."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, results):
        results['img'] = TF.normalize(results['img'], self.mean, self.std)

        return results

@PIPELINES.register_module()
class OnlineCropOCR:
    """Crop text areas from whole image with bounding box jitter. If no bbox
    is given, return directly.

    Args:
        box_keys (list[str]): Keys in results which correspond to RoI bbox.
        jitter_prob (float): The probability of box jitter.
        max_jitter_ratio_x (float): Maximum horizontal jitter ratio
            relative to height.
        max_jitter_ratio_y (float): Maximum vertical jitter ratio
            relative to height.
    """

    def __init__(self,
                 box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'],
                 jitter_prob=0.5,
                 max_jitter_ratio_x=0.05,
                 max_jitter_ratio_y=0.02):
        assert utils.is_type_list(box_keys, str)
        assert 0 <= jitter_prob <= 1
        assert 0 <= max_jitter_ratio_x <= 1
        assert 0 <= max_jitter_ratio_y <= 1

        self.box_keys = box_keys
        self.jitter_prob = jitter_prob
        self.max_jitter_ratio_x = max_jitter_ratio_x
        self.max_jitter_ratio_y = max_jitter_ratio_y

    def __call__(self, results):

        if 'img_info' not in results:
            return results

        crop_flag = True
        box = []
        for key in self.box_keys:
            if key not in results['img_info']:
                crop_flag = False
                break

            box.append(float(results['img_info'][key]))

        if not crop_flag:
            return results

        # apply jitter with probability jitter_prob, matching the docstring
        jitter_flag = np.random.random() < self.jitter_prob

        kwargs = dict(
            jitter_flag=jitter_flag,
            jitter_ratio_x=self.max_jitter_ratio_x,
            jitter_ratio_y=self.max_jitter_ratio_y)
        crop_img = warp_img(results['img'], box, **kwargs)

        results['img'] = crop_img
        results['img_shape'] = crop_img.shape

        return results

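A sketch of the results dict OnlineCropOCR expects; the image and corner values are invented, and only the 'img_info' corner keys read above are assumed.

results = dict(
    img=np.zeros((100, 200, 3), dtype=np.uint8),  # hypothetical page image
    img_info=dict(x1=10, y1=20, x2=90, y2=22, x3=88, y3=60, x4=8, y4=58))
results = OnlineCropOCR()(results)
# results['img'] is now the rectified text patch and results['img_shape']
# its new shape; if any corner key were missing, the dict would pass
# through unchanged.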
@PIPELINES.register_module()
class FancyPCA:
    """Implementation of PCA based image augmentation, proposed in the paper
    ``Imagenet Classification With Deep Convolutional Neural Networks``.

    It alters the intensities of RGB values along the principal components of
    the ImageNet dataset.
    """

    def __init__(self, eig_vec=None, eig_val=None):
        if eig_vec is None:
            eig_vec = torch.Tensor([
                [-0.5675, +0.7192, +0.4009],
                [-0.5808, -0.0045, -0.8140],
                [-0.5836, -0.6948, +0.4203],
            ]).t()
        if eig_val is None:
            eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]])
        self.eig_val = eig_val  # 1*3
        self.eig_vec = eig_vec  # 3*3

    def pca(self, tensor):
        assert tensor.size(0) == 3
        alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1
        reconst = torch.mm(self.eig_val * alpha, self.eig_vec)
        tensor = tensor + reconst.view(3, 1, 1)

        return tensor

    def __call__(self, results):
        img = results['img']
        tensor = self.pca(img)
        results['img'] = tensor

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str

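The perturbation added by pca above is a single RGB offset broadcast over all pixels: with alpha drawn per channel from N(0, 0.1^2), the offset is (eig_val * alpha) @ eig_vec, shaped (1, 3) and viewed as (3, 1, 1). A quick sanity check under these assumptions:

img = torch.rand(3, 32, 100)          # hypothetical normalized CHW image
out = FancyPCA().pca(img)
assert out.shape == img.shape
# every pixel of a given channel is shifted by the same scalar:
delta = out - img
assert torch.allclose(delta, delta[:, :1, :1].expand_as(delta))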
@PIPELINES.register_module()
class RandomPaddingOCR:
    """Pad the given image on all sides, as well as modify the coordinates of
    character bounding boxes in the image.

    Args:
        max_ratio (list[float]): Maximum padding ratios of the
            [left, top, right, bottom] sides.
        box_type (None|str): Character box type. If not none,
            should be either 'char_rects' or 'char_quads', with
            'char_rects' for rectangle with ``xyxy`` style and
            'char_quads' for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
    """

    def __init__(self, max_ratio=None, box_type=None):
        if max_ratio is None:
            max_ratio = [0.1, 0.2, 0.1, 0.2]
        else:
            assert utils.is_type_list(max_ratio, float)
            assert len(max_ratio) == 4
        assert box_type is None or box_type in ('char_rects', 'char_quads')

        self.max_ratio = max_ratio
        self.box_type = box_type

    def __call__(self, results):

        img_shape = results['img_shape']
        ori_height, ori_width = img_shape[:2]

        random_padding_left = round(
            np.random.uniform(0, self.max_ratio[0]) * ori_width)
        random_padding_top = round(
            np.random.uniform(0, self.max_ratio[1]) * ori_height)
        random_padding_right = round(
            np.random.uniform(0, self.max_ratio[2]) * ori_width)
        random_padding_bottom = round(
            np.random.uniform(0, self.max_ratio[3]) * ori_height)

        img = np.copy(results['img'])
        img = cv2.copyMakeBorder(img, random_padding_top,
                                 random_padding_bottom, random_padding_left,
                                 random_padding_right, cv2.BORDER_REPLICATE)
        results['img'] = img
        results['img_shape'] = img.shape

        if self.box_type is not None:
            num_points = 2 if self.box_type == 'char_rects' else 4
            char_num = len(results['ann_info'][self.box_type])
            for i in range(char_num):
                for j in range(num_points):
                    results['ann_info'][self.box_type][i][
                        j * 2] += random_padding_left
                    results['ann_info'][self.box_type][i][
                        j * 2 + 1] += random_padding_top

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str

@PIPELINES.register_module()
class RandomRotateImageBox:
    """Rotate augmentation for segmentation based text recognition.

    Args:
        min_angle (int): Minimum rotation angle for image and box.
        max_angle (int): Maximum rotation angle for image and box.
        box_type (str): Character box type, should be either
            'char_rects' or 'char_quads', with 'char_rects'
            for rectangle with ``xyxy`` style and 'char_quads'
            for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
    """

    def __init__(self, min_angle=-10, max_angle=10, box_type='char_quads'):
        assert box_type in ('char_rects', 'char_quads')

        self.min_angle = min_angle
        self.max_angle = max_angle
        self.box_type = box_type

    def __call__(self, results):
        in_img = results['img']
        in_chars = results['ann_info']['chars']
        in_boxes = results['ann_info'][self.box_type]

        img_width, img_height = in_img.size
        rotate_center = [img_width / 2., img_height / 2.]

        tan_temp_max_angle = rotate_center[1] / rotate_center[0]
        temp_max_angle = np.arctan(tan_temp_max_angle) * 180. / np.pi

        random_angle = np.random.uniform(
            max(self.min_angle, -temp_max_angle),
            min(self.max_angle, temp_max_angle))
        random_angle_radian = random_angle * np.pi / 180.

        img_box = shapely_box(0, 0, img_width, img_height)

        out_img = TF.rotate(
            in_img,
            random_angle,
            resample=False,
            expand=False,
            center=rotate_center)

        out_boxes, out_chars = self.rotate_bbox(in_boxes, in_chars,
                                                random_angle_radian,
                                                rotate_center, img_box)

        results['img'] = out_img
        results['ann_info']['chars'] = out_chars
        results['ann_info'][self.box_type] = out_boxes

        return results

    @staticmethod
    def rotate_bbox(boxes, chars, angle, center, img_box):
        out_boxes = []
        out_chars = []
        for idx, bbox in enumerate(boxes):
            temp_bbox = []
            for i in range(len(bbox) // 2):
                point = [bbox[2 * i], bbox[2 * i + 1]]
                temp_bbox.append(
                    RandomRotateImageBox.rotate_point(point, angle, center))
            poly_temp_bbox = Polygon(temp_bbox).buffer(0)
            if poly_temp_bbox.is_valid:
                if img_box.intersects(poly_temp_bbox) and (
                        not img_box.touches(poly_temp_bbox)):
                    temp_bbox_area = poly_temp_bbox.area

                    intersect_area = img_box.intersection(poly_temp_bbox).area
                    intersect_ratio = intersect_area / temp_bbox_area

                    if intersect_ratio >= 0.7:
                        out_box = []
                        for p in temp_bbox:
                            out_box.extend(p)
                        out_boxes.append(out_box)
                        out_chars.append(chars[idx])

        return out_boxes, out_chars

    @staticmethod
    def rotate_point(point, angle, center):
        cos_theta = math.cos(-angle)
        sin_theta = math.sin(-angle)
        c_x = center[0]
        c_y = center[1]
        new_x = (point[0] - c_x) * cos_theta - (point[1] -
                                                c_y) * sin_theta + c_x
        new_y = (point[0] - c_x) * sin_theta + (point[1] -
                                                c_y) * cos_theta + c_y

        return [new_x, new_y]

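A small numeric check of rotate_point under invented values: rotating (2, 1) by angle = pi/2 around center (1, 1) uses cos(-pi/2) = 0 and sin(-pi/2) = -1, so new_x = 1*0 - 0*(-1) + 1 = 1 and new_y = 1*(-1) + 0*0 + 1 = 0, i.e. the point one unit right of the center moves one unit above it (in image coordinates, y points down).

p = RandomRotateImageBox.rotate_point([2, 1], math.pi / 2, [1, 1])
assert abs(p[0] - 1) < 1e-9 and abs(p[1] - 0) < 1e-9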
@PIPELINES.register_module()
class OpencvToPil:
    """Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb)."""

    def __init__(self, **kwargs):
        pass

    def __call__(self, results):
        img = results['img'][..., ::-1]
        img = Image.fromarray(img)
        results['img'] = img

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str

@PIPELINES.register_module()
class PilToOpencv:
    """Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr)."""

    def __init__(self, **kwargs):
        pass

    def __call__(self, results):
        img = np.asarray(results['img'])
        img = img[..., ::-1]
        results['img'] = img

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str
@ -0,0 +1,108 @@
import mmcv
import numpy as np

from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines.compose import Compose


@PIPELINES.register_module()
class MultiRotateAugOCR:
    """Test-time augmentation with multiple rotations in the case that
    img_height > img_width.

    An example configuration is as follows:

    .. code-block::

        rotate_degrees=[0, 90, 270],
        transforms=[
            dict(
                type='ResizeOCR',
                height=32,
                min_width=32,
                max_width=160,
                keep_aspect_ratio=True),
            dict(type='ToTensorOCR'),
            dict(type='NormalizeOCR', **img_norm_cfg),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=[
                    'filename', 'ori_shape', 'img_shape', 'valid_ratio'
                ]),
        ]

    After MultiRotateAugOCR with the above configuration, the results are
    wrapped into lists of the same length as follows:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...]
            ...
        )

    Args:
        transforms (list[dict]): Transformation applied for each augmentation.
        rotate_degrees (list[int] | None): Degrees of anti-clockwise rotation.
        force_rotate (bool): If True, rotate image by 'rotate_degrees'
            while ignoring the image aspect ratio.
    """

    def __init__(self, transforms, rotate_degrees=None, force_rotate=False):
        self.transforms = Compose(transforms)
        self.force_rotate = force_rotate
        if rotate_degrees is not None:
            self.rotate_degrees = rotate_degrees if isinstance(
                rotate_degrees, list) else [rotate_degrees]
            assert mmcv.is_list_of(self.rotate_degrees, int)
            for degree in self.rotate_degrees:
                assert 0 <= degree < 360
                assert degree % 90 == 0
            if 0 not in self.rotate_degrees:
                self.rotate_degrees.append(0)
        else:
            self.rotate_degrees = [0]

    def __call__(self, results):
        """Call function to apply test time augment transformation to results.

        Args:
            results (dict): Result dict contains the data to be transformed.

        Returns:
            dict[str: list]: The augmented data, where each value is wrapped
                into a list.
        """
        img_shape = results['img_shape']
        ori_height, ori_width = img_shape[:2]
        if not self.force_rotate and ori_height <= ori_width:
            rotate_degrees = [0]
        else:
            rotate_degrees = self.rotate_degrees
        aug_data = []
        for degree in set(rotate_degrees):
            _results = results.copy()
            if degree == 0:
                pass
            elif degree == 90:
                _results['img'] = np.rot90(_results['img'], 1)
            elif degree == 180:
                _results['img'] = np.rot90(_results['img'], 2)
            elif degree == 270:
                _results['img'] = np.rot90(_results['img'], 3)
            data = self.transforms(_results)
            aug_data.append(data)
        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'rotate_degrees={self.rotate_degrees})'
        return repr_str
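A sketch of the list-of-dict to dict-of-list rearrangement at the end of __call__, replayed on invented toy dicts:

aug_data = [dict(img='img@0deg', img_shape=(32, 100)),
            dict(img='img@90deg', img_shape=(32, 60))]
aug_data_dict = {key: [] for key in aug_data[0]}
for data in aug_data:
    for key, val in data.items():
        aug_data_dict[key].append(val)
# aug_data_dict == dict(img=['img@0deg', 'img@90deg'],
#                       img_shape=[(32, 100), (32, 60)])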
@ -0,0 +1,121 @@
import numpy as np

from mmdet.datasets.builder import DATASETS
from mmocr.core.evaluation.hmean import eval_hmean
from mmocr.datasets.base_dataset import BaseDataset


@DATASETS.register_module()
class TextDetDataset(BaseDataset):

    def _parse_anno_info(self, annotations):
        """Parse bbox and mask annotation.

        Args:
            annotations (dict): Annotations of one image.

        Returns:
            dict: A dict containing the following keys: bboxes, bboxes_ignore,
                labels, masks, masks_ignore. "masks" and
                "masks_ignore" are represented by polygon boundary
                point sequences.
        """
        gt_bboxes, gt_bboxes_ignore = [], []
        gt_masks, gt_masks_ignore = [], []
        gt_labels = []
        for ann in annotations:
            if ann.get('iscrowd', False):
                gt_bboxes_ignore.append(ann['bbox'])
                gt_masks_ignore.append(ann.get('segmentation', None))
            else:
                gt_bboxes.append(ann['bbox'])
                gt_labels.append(ann['category_id'])
                gt_masks.append(ann.get('segmentation', None))
        if gt_bboxes:
            gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
            gt_labels = np.array(gt_labels, dtype=np.int64)
        else:
            gt_bboxes = np.zeros((0, 4), dtype=np.float32)
            gt_labels = np.array([], dtype=np.int64)

        if gt_bboxes_ignore:
            gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
        else:
            gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)

        ann = dict(
            bboxes=gt_bboxes,
            labels=gt_labels,
            bboxes_ignore=gt_bboxes_ignore,
            masks_ignore=gt_masks_ignore,
            masks=gt_masks)

        return ann

    def prepare_train_img(self, index):
        """Get training data and annotations from pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        img_ann_info = self.data_infos[index]
        img_info = {
            'filename': img_ann_info['file_name'],
            'height': img_ann_info['height'],
            'width': img_ann_info['width']
        }
        ann_info = self._parse_anno_info(img_ann_info['annotations'])
        results = dict(img_info=img_info, ann_info=ann_info)
        results['bbox_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []
        self.pre_pipeline(results)

        return self.pipeline(results)

    def evaluate(self,
                 results,
                 metric='hmean-iou',
                 score_thr=0.3,
                 rank_list=None,
                 logger=None,
                 **kwargs):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            score_thr (float): Score threshold for prediction map.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            rank_list (str): Json file used to save the eval result
                of each image after ranking.

        Returns:
            dict[str: float]
        """
        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['hmean-iou', 'hmean-ic13']
        metrics = set(metrics) & set(allowed_metrics)

        img_infos = []
        ann_infos = []
        for i in range(len(self)):
            img_ann_info = self.data_infos[i]
            img_info = {'filename': img_ann_info['file_name']}
            ann_info = self._parse_anno_info(img_ann_info['annotations'])
            img_infos.append(img_info)
            ann_infos.append(ann_info)

        eval_results = eval_hmean(
            results,
            img_infos,
            ann_infos,
            metrics=metrics,
            score_thr=score_thr,
            logger=logger,
            rank_list=rank_list)

        return eval_results
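A hedged usage sketch for evaluate; 'dataset' and 'model_results' are placeholders for an instantiated TextDetDataset and the detector outputs collected by a test loop:

# model_results[i] holds the predictions for image i in dataset order.
metrics = dataset.evaluate(model_results, metric='hmean-iou', score_thr=0.3)
# metrics is a dict of floats produced by eval_hmean, one entry per
# requested metric component.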
@ -0,0 +1,4 @@
from .loader import HardDiskLoader, LmdbLoader
from .parser import LineJsonParser, LineStrParser

__all__ = ['HardDiskLoader', 'LmdbLoader', 'LineStrParser', 'LineJsonParser']
@ -0,0 +1,108 @@
import os.path as osp

from mmocr.datasets.builder import LOADERS, build_parser


@LOADERS.register_module()
class Loader:
    """Load annotation from annotation file, and parse instance information to
    dict format with parser.

    Args:
        ann_file (str): Annotation file path.
        parser (dict): Dictionary to construct parser
            to parse original annotation infos.
        repeat (int): Repeated times of annotations.
    """

    def __init__(self, ann_file, parser, repeat=1):
        assert isinstance(ann_file, str)
        assert isinstance(repeat, int)
        assert isinstance(parser, dict)
        assert repeat > 0
        assert osp.exists(ann_file), f'{ann_file} does not exist'

        self.ori_data_infos = self._load(ann_file)
        self.parser = build_parser(parser)
        self.repeat = repeat

    def __len__(self):
        return len(self.ori_data_infos) * self.repeat

    def _load(self, ann_file):
        """Load annotation file."""
        raise NotImplementedError

    def __getitem__(self, index):
        """Retrieve anno info of one instance with dict format."""
        return self.parser.get_item(self.ori_data_infos, index)


@LOADERS.register_module()
class HardDiskLoader(Loader):
    """Load annotation file from hard disk to RAM.

    Args:
        ann_file (str): Annotation file path.
    """

    def _load(self, ann_file):
        data_ret = []
        with open(ann_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                data_ret.append(line)

        return data_ret


@LOADERS.register_module()
class LmdbLoader(Loader):
    """Load annotation file with lmdb storage backend."""

    def _load(self, ann_file):
        lmdb_anno_obj = LmdbAnnFileBackend(ann_file)

        return lmdb_anno_obj


class LmdbAnnFileBackend:
    """Lmdb storage backend for annotation file.

    Args:
        lmdb_path (str): Lmdb file path.
    """

    def __init__(self, lmdb_path, coding='utf8'):
        self.lmdb_path = lmdb_path
        self.coding = coding
        env = self._get_env()
        with env.begin(write=False) as txn:
            self.total_number = int(
                txn.get('total_number'.encode(self.coding)).decode(
                    self.coding))

    def __getitem__(self, index):
        """Retrieve one line from the lmdb file by index."""
        # only attach env to self when __getitem__ is called
        # because the env object cannot be pickled
        if not hasattr(self, 'env'):
            self.env = self._get_env()

        with self.env.begin(write=False) as txn:
            line = txn.get(str(index).encode(self.coding)).decode(self.coding)
        return line

    def __len__(self):
        return self.total_number

    def _get_env(self):
        import lmdb
        return lmdb.open(
            self.lmdb_path,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
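From the keys read above, the lmdb backend expects a database with a 'total_number' entry and one entry per line keyed by its stringified index. A sketch of building such a file with the lmdb package; the path and content are invented:

import lmdb

env = lmdb.open('toy_ann_lmdb', map_size=2**20)
lines = ['img_1.jpg hello', 'img_2.jpg world']
with env.begin(write=True) as txn:
    for i, line in enumerate(lines):
        txn.put(str(i).encode('utf8'), line.encode('utf8'))
    txn.put(b'total_number', str(len(lines)).encode('utf8'))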
@ -0,0 +1,69 @@
import json

from mmocr.datasets.builder import PARSERS


@PARSERS.register_module()
class LineStrParser:
    """Parse string of one line in annotation file to dict format.

    Args:
        keys (list[str]): Keys in result dict.
        keys_idx (list[int]): Value index in sub-string list
            for each key above.
        separator (str): Separator to separate string to list of sub-string.
    """

    def __init__(self,
                 keys=['filename', 'text'],
                 keys_idx=[0, 1],
                 separator=' '):
        assert isinstance(keys, list)
        assert isinstance(keys_idx, list)
        assert isinstance(separator, str)
        assert len(keys) > 0
        assert len(keys) == len(keys_idx)
        self.keys = keys
        self.keys_idx = keys_idx
        self.separator = separator

    def get_item(self, data_ret, index):
        map_index = index % len(data_ret)
        line_str = data_ret[map_index]
        for split_key in self.separator:
            if split_key != ' ':
                line_str = line_str.replace(split_key, ' ')
        line_str = line_str.split()
        if len(line_str) <= max(self.keys_idx):
            raise Exception(
                f'key index: {max(self.keys_idx)} out of range: {line_str}')

        line_info = {}
        for i, key in enumerate(self.keys):
            line_info[key] = line_str[self.keys_idx[i]]
        return line_info


@PARSERS.register_module()
class LineJsonParser:
    """Parse json-string of one line in annotation file to dict format.

    Args:
        keys (list[str]): Keys in both json-string and result dict.
    """

    def __init__(self, keys=[], **kwargs):
        assert isinstance(keys, list)
        assert len(keys) > 0
        self.keys = keys

    def get_item(self, data_ret, index):
        map_index = index % len(data_ret)
        line_json_obj = json.loads(data_ret[map_index])
        line_info = {}
        for key in self.keys:
            if key not in line_json_obj:
                raise Exception(f'key {key} not in line json {line_json_obj}')
            line_info[key] = line_json_obj[key]

        return line_info
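A quick illustration of both parsers on invented annotation lines:

parser = LineStrParser(keys=['filename', 'text'], keys_idx=[0, 1],
                       separator=' ')
parser.get_item(['img_1.jpg hello'], 0)
# -> {'filename': 'img_1.jpg', 'text': 'hello'}

json_parser = LineJsonParser(keys=['filename', 'text'])
json_parser.get_item(['{"filename": "img_1.jpg", "text": "hello"}'], 0)
# -> {'filename': 'img_1.jpg', 'text': 'hello'}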
@ -0,0 +1,8 @@
from .backbones import *  # noqa: F401,F403
from .convertors import *  # noqa: F401,F403
from .decoders import *  # noqa: F401,F403
from .encoders import *  # noqa: F401,F403
from .heads import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
from .recognizer import *  # noqa: F401,F403
@ -0,0 +1,4 @@
from .resnet31_ocr import ResNet31OCR
from .very_deep_vgg import VeryDeepVgg

__all__ = ['ResNet31OCR', 'VeryDeepVgg']
@ -0,0 +1,149 @@
import torch.nn as nn
from mmcv.cnn import kaiming_init, uniform_init

import mmocr.utils as utils
from mmdet.models.builder import BACKBONES
from mmocr.models.textrecog.layers import BasicBlock


@BACKBONES.register_module()
class ResNet31OCR(nn.Module):
    """Implement ResNet backbone for text recognition, modified from
    `ResNet <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        base_channels (int): Number of channels of input image tensor.
        layers (list[int]): List of BasicBlock number for each stage.
        channels (list[int]): List of out_channels of Conv2d layer.
        out_indices (None | Sequence[int]): Indices of output stages.
        stage4_pool_cfg (dict): Dictionary to construct and configure
            pooling layer in stage 4.
        last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
    """

    def __init__(self,
                 base_channels=3,
                 layers=[1, 2, 5, 3],
                 channels=[64, 128, 256, 256, 512, 512, 512],
                 out_indices=None,
                 stage4_pool_cfg=dict(kernel_size=(2, 1), stride=(2, 1)),
                 last_stage_pool=False):
        super().__init__()
        assert isinstance(base_channels, int)
        assert utils.is_type_list(layers, int)
        assert utils.is_type_list(channels, int)
        assert out_indices is None or (isinstance(out_indices, list)
                                       or isinstance(out_indices, tuple))
        assert isinstance(last_stage_pool, bool)

        self.out_indices = out_indices
        self.last_stage_pool = last_stage_pool

        # conv 1 (Conv, Conv)
        self.conv1_1 = nn.Conv2d(
            base_channels, channels[0], kernel_size=3, stride=1, padding=1)
        self.bn1_1 = nn.BatchNorm2d(channels[0])
        self.relu1_1 = nn.ReLU(inplace=True)

        self.conv1_2 = nn.Conv2d(
            channels[0], channels[1], kernel_size=3, stride=1, padding=1)
        self.bn1_2 = nn.BatchNorm2d(channels[1])
        self.relu1_2 = nn.ReLU(inplace=True)

        # conv 2 (Max-pooling, Residual block, Conv)
        self.pool2 = nn.MaxPool2d(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self.block2 = self._make_layer(channels[1], channels[2], layers[0])
        self.conv2 = nn.Conv2d(
            channels[2], channels[2], kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(channels[2])
        self.relu2 = nn.ReLU(inplace=True)

        # conv 3 (Max-pooling, Residual block, Conv)
        self.pool3 = nn.MaxPool2d(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self.block3 = self._make_layer(channels[2], channels[3], layers[1])
        self.conv3 = nn.Conv2d(
            channels[3], channels[3], kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(channels[3])
        self.relu3 = nn.ReLU(inplace=True)

        # conv 4 (Max-pooling, Residual block, Conv)
        self.pool4 = nn.MaxPool2d(padding=0, ceil_mode=True,
                                  **stage4_pool_cfg)
        self.block4 = self._make_layer(channels[3], channels[4], layers[2])
        self.conv4 = nn.Conv2d(
            channels[4], channels[4], kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(channels[4])
        self.relu4 = nn.ReLU(inplace=True)

        # conv 5 ((Max-pooling), Residual block, Conv)
        self.pool5 = None
        if self.last_stage_pool:
            self.pool5 = nn.MaxPool2d(
                kernel_size=2, stride=2, padding=0, ceil_mode=True)  # 1/16
        self.block5 = self._make_layer(channels[4], channels[5], layers[3])
        self.conv5 = nn.Conv2d(
            channels[5], channels[5], kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(channels[5])
        self.relu5 = nn.ReLU(inplace=True)

    def init_weights(self, pretrained=None):
        # initialize weight and bias
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                uniform_init(m)

    def _make_layer(self, input_channels, output_channels, blocks):
        layers = []
        for _ in range(blocks):
            downsample = None
            if input_channels != output_channels:
                downsample = nn.Sequential(
                    nn.Conv2d(
                        input_channels,
                        output_channels,
                        kernel_size=1,
                        stride=1,
                        bias=False),
                    nn.BatchNorm2d(output_channels),
                )
            layers.append(
                BasicBlock(
                    input_channels, output_channels, downsample=downsample))
            input_channels = output_channels

        return nn.Sequential(*layers)

    def forward(self, x):

        x = self.conv1_1(x)
        x = self.bn1_1(x)
        x = self.relu1_1(x)

        x = self.conv1_2(x)
        x = self.bn1_2(x)
        x = self.relu1_2(x)

        outs = []
        for i in range(4):
            layer_index = i + 2
            pool_layer = getattr(self, f'pool{layer_index}')
            block_layer = getattr(self, f'block{layer_index}')
            conv_layer = getattr(self, f'conv{layer_index}')
            bn_layer = getattr(self, f'bn{layer_index}')
            relu_layer = getattr(self, f'relu{layer_index}')

            if pool_layer is not None:
                x = pool_layer(x)
            x = block_layer(x)
            x = conv_layer(x)
            x = bn_layer(x)
            x = relu_layer(x)

            outs.append(x)

        if self.out_indices is not None:
            return tuple([outs[i] for i in self.out_indices])

        return x
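A shape sketch for the default configuration (a sanity check, assuming BasicBlock preserves spatial size): pool2 and pool3 halve both dimensions, pool4 halves only the height by default, and pool5 is disabled, so the feature map is downsampled by 8 in height and 4 in width.

import torch

backbone = ResNet31OCR()
feat = backbone(torch.rand(1, 3, 32, 160))
# feat.shape == torch.Size([1, 512, 4, 40])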
@ -0,0 +1,6 @@
from .attn import AttnConvertor
from .base import BaseConvertor
from .ctc import CTCConvertor
from .seg import SegConvertor

__all__ = ['BaseConvertor', 'CTCConvertor', 'AttnConvertor', 'SegConvertor']
@ -0,0 +1,140 @@
import torch

import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from .base import BaseConvertor


@CONVERTORS.register_module()
class AttnConvertor(BaseConvertor):
    """Convert between text, index and tensor for encoder-decoder based
    pipeline.

    Args:
        dict_type (str): Type of dict, should be one of {'DICT36', 'DICT90'}.
        dict_file (None|str): Character dict file path. If not none,
            higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, higher
            priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        max_seq_len (int): Maximum sequence length of label.
        lower (bool): If True, convert original string to lower case.
        start_end_same (bool): Whether to use the same index for
            start and end token or not. Default: True.
    """

    def __init__(self,
                 dict_type='DICT90',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 max_seq_len=40,
                 lower=False,
                 start_end_same=True,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(max_seq_len, int)
        assert isinstance(lower, bool)

        self.with_unknown = with_unknown
        self.max_seq_len = max_seq_len
        self.lower = lower
        self.start_end_same = start_end_same

        self.update_dict()

    def update_dict(self):
        start_end_token = '<BOS/EOS>'
        unknown_token = '<UKN>'
        padding_token = '<PAD>'

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self.idx2char.append(unknown_token)
            self.unknown_idx = len(self.idx2char) - 1

        # BOS/EOS
        self.idx2char.append(start_end_token)
        self.start_idx = len(self.idx2char) - 1
        if not self.start_end_same:
            self.idx2char.append(start_end_token)
        self.end_idx = len(self.idx2char) - 1

        # padding
        self.idx2char.append(padding_token)
        self.padding_idx = len(self.idx2char) - 1

        # update char2idx
        self.char2idx = {}
        for idx, char in enumerate(self.idx2char):
            self.char2idx[char] = idx

    def str2tensor(self, strings):
        """Convert text-string into tensor.

        Args:
            strings (list[str]): ['hello', 'world']
        Returns:
            dict (str: Tensor | list[tensor]):
                tensors (list[Tensor]): [torch.Tensor([1,2,3,3,4]),
                    torch.Tensor([5,4,6,3,7])]
                padded_targets (Tensor(bsz * max_seq_len))
        """
        assert utils.is_type_list(strings, str)

        tensors, padded_targets = [], []
        indexes = self.str2idx(strings)
        for index in indexes:
            tensor = torch.LongTensor(index)
            tensors.append(tensor)
            # target tensor for loss
            src_target = torch.LongTensor(tensor.size(0) + 2).fill_(0)
            src_target[-1] = self.end_idx
            src_target[0] = self.start_idx
            src_target[1:-1] = tensor
            padded_target = (torch.ones(self.max_seq_len) *
                             self.padding_idx).long()
            char_num = src_target.size(0)
            if char_num > self.max_seq_len:
                padded_target = src_target[:self.max_seq_len]
            else:
                padded_target[:char_num] = src_target
            padded_targets.append(padded_target)
        padded_targets = torch.stack(padded_targets, 0).long()

        return {'targets': tensors, 'padded_targets': padded_targets}

    def tensor2idx(self, outputs, img_metas=None):
        """Convert output tensor to text-index.

        Args:
            outputs (tensor): model outputs with size: N * T * C
            img_metas (list[dict]): Each dict contains one image info.
        Returns:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]]
            scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
                [0.9,0.9,0.98,0.97,0.96]]
        """
        batch_size = outputs.size(0)
        ignore_indexes = [self.padding_idx]
        indexes, scores = [], []
        for idx in range(batch_size):
            seq = outputs[idx, :, :]
            max_value, max_idx = torch.max(seq, -1)
            str_index, str_score = [], []
            output_index = max_idx.cpu().detach().numpy().tolist()
            output_score = max_value.cpu().detach().numpy().tolist()
            for char_index, char_score in zip(output_index, output_score):
                if char_index in ignore_indexes:
                    continue
                if char_index == self.end_idx:
                    break
                str_index.append(char_index)
                str_score.append(char_score)

            indexes.append(str_index)
            scores.append(str_score)

        return indexes, scores
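A sketch of how str2tensor pads under a DICT36 dict with the defaults above ('abc' is an invented label): with with_unknown=True the vocabulary becomes the 36 characters plus <UKN>(36), <BOS/EOS>(37) and <PAD>(38), so:

convertor = AttnConvertor(dict_type='DICT36', max_seq_len=8, lower=True)
out = convertor.str2tensor(['abc'])
# out['targets'][0]        -> tensor([10, 11, 12])
# out['padded_targets'][0] -> tensor([37, 10, 11, 12, 37, 38, 38, 38]),
# i.e. <BOS> a b c <EOS> followed by <PAD> up to max_seq_len.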
@ -0,0 +1,115 @@
from mmocr.models.builder import CONVERTORS


@CONVERTORS.register_module()
class BaseConvertor:
    """Convert between text, index and tensor for text recognize pipeline.

    Args:
        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
        dict_file (None|str): Character dict file path. If not none,
            the dict_file is of higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, the list
            is of higher priority than dict_type, but lower than dict_file.
    """
    start_idx = end_idx = padding_idx = 0
    unknown_idx = None
    lower = False

    DICT36 = tuple('0123456789abcdefghijklmnopqrstuvwxyz')
    DICT90 = tuple('0123456789abcdefghijklmnopqrstuvwxyz'
                   'ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()'
                   '*+,-./:;<=>?@[\\]_`~')

    def __init__(self, dict_type='DICT90', dict_file=None, dict_list=None):
        assert dict_type in ('DICT36', 'DICT90')
        assert dict_file is None or isinstance(dict_file, str)
        assert dict_list is None or isinstance(dict_list, list)
        self.idx2char = []
        if dict_file is not None:
            with open(dict_file, encoding='utf-8') as fr:
                for line in fr:
                    line = line.strip()
                    if line != '':
                        self.idx2char.append(line)
        elif dict_list is not None:
            self.idx2char = dict_list
        else:
            if dict_type == 'DICT36':
                self.idx2char = list(self.DICT36)
            else:
                self.idx2char = list(self.DICT90)

        self.char2idx = {}
        for idx, char in enumerate(self.idx2char):
            self.char2idx[char] = idx

    def num_classes(self):
        """Number of output classes."""
        return len(self.idx2char)

    def str2idx(self, strings):
        """Convert strings to indexes.

        Args:
            strings (list[str]): ['hello', 'world'].
        Returns:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
        """
        assert isinstance(strings, list)

        indexes = []
        for string in strings:
            if self.lower:
                string = string.lower()
            index = []
            for char in string:
                char_idx = self.char2idx.get(char, self.unknown_idx)
                if char_idx is None:
                    raise Exception(f'Character: {char} not in dict,'
                                    f' please check gt_label and use'
                                    f' a custom dict file,'
                                    f' or set "with_unknown=True"')
                index.append(char_idx)
            indexes.append(index)

        return indexes

    def str2tensor(self, strings):
        """Convert text-string to input tensor.

        Args:
            strings (list[str]): ['hello', 'world'].
        Returns:
            tensors (list[torch.Tensor]): [torch.Tensor([1,2,3,3,4]),
                torch.Tensor([5,4,6,3,7])].
        """
        raise NotImplementedError

    def idx2str(self, indexes):
        """Convert indexes to text strings.

        Args:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
        Returns:
            strings (list[str]): ['hello', 'world'].
        """
        assert isinstance(indexes, list)

        strings = []
        for index in indexes:
            string = [self.idx2char[i] for i in index]
            strings.append(''.join(string))

        return strings

    def tensor2idx(self, output):
        """Convert model output tensor to character indexes and scores.

        Args:
            output (tensor): The model outputs with size: N * T * C.
        Returns:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
            scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
                [0.9,0.9,0.98,0.97,0.96]].
        """
        raise NotImplementedError
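A round trip through the two implemented helpers, using the built-in DICT36 ('0'-'9' map to 0-9, 'a'-'z' to 10-35):

convertor = BaseConvertor(dict_type='DICT36')
idx = convertor.str2idx(['hello'])   # [[17, 14, 21, 21, 24]]
convertor.idx2str(idx)               # ['hello']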
@ -0,0 +1,144 @@
import math

import torch
import torch.nn.functional as F

import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from .base import BaseConvertor


@CONVERTORS.register_module()
class CTCConvertor(BaseConvertor):
    """Convert between text, index and tensor for CTC loss-based pipeline.

    Args:
        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
        dict_file (None|str): Character dict file path. If not none, the file
            is of higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, the list
            is of higher priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        lower (bool): If True, convert original string to lower case.
    """

    def __init__(self,
                 dict_type='DICT90',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 lower=False,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(lower, bool)

        self.with_unknown = with_unknown
        self.lower = lower
        self.update_dict()

    def update_dict(self):
        # CTC-blank
        blank_token = '<BLK>'
        self.blank_idx = 0
        self.idx2char.insert(0, blank_token)

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self.idx2char.append('<UKN>')
            self.unknown_idx = len(self.idx2char) - 1

        # update char2idx
        self.char2idx = {}
        for idx, char in enumerate(self.idx2char):
            self.char2idx[char] = idx

    def str2tensor(self, strings):
        """Convert text-string to ctc-loss input tensor.

        Args:
            strings (list[str]): ['hello', 'world'].
        Returns:
            dict (str: tensor | list[tensor]):
                targets (list[tensor]): [torch.Tensor([1,2,3,3,4]),
                    torch.Tensor([5,4,6,3,7])].
                flatten_targets (tensor): torch.Tensor([1,2,3,3,4,5,4,6,3,7]).
                target_lengths (tensor): torch.IntTensor([5,5]).
        """
        assert utils.is_type_list(strings, str)

        tensors = []
        indexes = self.str2idx(strings)
        for index in indexes:
            tensor = torch.IntTensor(index)
            tensors.append(tensor)
        target_lengths = torch.IntTensor([len(t) for t in tensors])
        flatten_target = torch.cat(tensors)

        return {
            'targets': tensors,
            'flatten_targets': flatten_target,
            'target_lengths': target_lengths
        }

    def tensor2idx(self, output, img_metas, topk=1, return_topk=False):
        """Convert model output tensor to index-list.

        Args:
            output (tensor): The model outputs with size: N * T * C.
            img_metas (list[dict]): Each dict contains one image info.
            topk (int): The highest k classes to be returned.
            return_topk (bool): Whether to return topk or just top1.
        Returns:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
            scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
                [0.9,0.9,0.98,0.97,0.96]]
            (
                indexes_topk (list[list[list[int]->len=topk]]):
                scores_topk (list[list[list[float]->len=topk]])
            ).
        """
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == output.size(0)
        assert isinstance(topk, int)
        assert topk >= 1

        valid_ratios = [
            img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
        ]

        batch_size = output.size(0)
        output = F.softmax(output, dim=2)
        output = output.cpu().detach()
        batch_topk_value, batch_topk_idx = output.topk(topk, dim=2)
        batch_max_idx = batch_topk_idx[:, :, 0]
        scores_topk, indexes_topk = [], []
        scores, indexes = [], []
        feat_len = output.size(1)
        for b in range(batch_size):
            valid_ratio = valid_ratios[b]
            decode_len = min(feat_len, math.ceil(feat_len * valid_ratio))
            pred = batch_max_idx[b, :]
            select_idx = []
            prev_idx = self.blank_idx
            for t in range(decode_len):
                tmp_value = pred[t].item()
                if tmp_value not in (prev_idx, self.blank_idx):
                    select_idx.append(t)
                prev_idx = tmp_value
            select_idx = torch.LongTensor(select_idx)
            topk_value = torch.index_select(batch_topk_value[b, :, :], 0,
                                            select_idx)  # valid_seqlen * topk
            topk_idx = torch.index_select(batch_topk_idx[b, :, :], 0,
                                          select_idx)
            topk_idx_list, topk_value_list = topk_idx.numpy().tolist(
            ), topk_value.numpy().tolist()
            indexes_topk.append(topk_idx_list)
            scores_topk.append(topk_value_list)
            indexes.append([x[0] for x in topk_idx_list])
            scores.append([x[0] for x in topk_value_list])

        if return_topk:
            return indexes_topk, scores_topk

        return indexes, scores
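A minimal usage sketch for the convertor, assuming it is importable as `mmocr.models.textrecog.convertors.CTCConvertor` and that DICT36 covers the input characters:

# Round-trip a batch of strings and greedy-decode a dummy model output.
import torch

from mmocr.models.textrecog.convertors import CTCConvertor

convertor = CTCConvertor(dict_type='DICT36', with_unknown=True, lower=True)

# Encode: text -> CTC targets (index 0 is reserved for the <BLK> token).
targets_dict = convertor.str2tensor(['hello', 'world'])
print(targets_dict['target_lengths'])  # tensor([5, 5], dtype=torch.int32)

# Decode: dummy logits of shape N * T * C -> de-duplicated index lists.
dummy_output = torch.randn(2, 20, len(convertor.idx2char))
img_metas = [dict(valid_ratio=1.0)] * 2
indexes, scores = convertor.tensor2idx(dummy_output, img_metas)
print(convertor.idx2str(indexes))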
@ -0,0 +1,123 @@
import cv2
import numpy as np
import torch

import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from .base import BaseConvertor


@CONVERTORS.register_module()
class SegConvertor(BaseConvertor):
    """Convert between text, index and tensor for segmentation based pipeline.

    Args:
        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
        dict_file (None|str): Character dict file path. If not none, the
            file is of higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, the list
            is of higher priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        lower (bool): If True, convert original string to lower case.
    """

    def __init__(self,
                 dict_type='DICT36',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 lower=False,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(lower, bool)

        self.with_unknown = with_unknown
        self.lower = lower
        self.update_dict()

    def update_dict(self):
        # background
        self.idx2char.insert(0, '<BG>')

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self.idx2char.append('<UKN>')
            self.unknown_idx = len(self.idx2char) - 1

        # update char2idx
        self.char2idx = {}
        for idx, char in enumerate(self.idx2char):
            self.char2idx[char] = idx

    def tensor2str(self, output, img_metas=None):
        """Convert model output tensor to string labels.

        Args:
            output (tensor): Model outputs with size: N * C * H * W.
            img_metas (list[dict]): Each dict contains one image info.
        Returns:
            texts (list[str]): Decoded text labels.
            scores (list[list[float]]): Decoded chars scores.
        """
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == output.size(0)

        texts, scores = [], []
        for b in range(output.size(0)):
            seg_pred = output[b].detach()
            seg_res = torch.argmax(
                seg_pred, dim=0).cpu().numpy().astype(np.int32)

            seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8)
            _, labels, stats, centroids = cv2.connectedComponentsWithStats(
                seg_thr)

            component_num = stats.shape[0]

            all_res = []
            for i in range(component_num):
                temp_loc = (labels == i)
                temp_value = seg_res[temp_loc]
                temp_center = centroids[i]

                temp_max_num = 0
                temp_max_cls = -1
                temp_total_num = 0
                for c in range(len(self.idx2char)):
                    c_num = np.sum(temp_value == c)
                    temp_total_num += c_num
                    if c_num > temp_max_num:
                        temp_max_num = c_num
                        temp_max_cls = c

                if temp_max_cls == 0:
                    continue
                temp_max_score = 1.0 * temp_max_num / temp_total_num
                all_res.append(
                    [temp_max_cls, temp_center, temp_max_num, temp_max_score])

            all_res = sorted(all_res, key=lambda s: s[1][0])
            chars, char_scores = [], []
            for res in all_res:
                temp_area = res[2]
                if temp_area < 20:
                    continue
                temp_char_index = res[0]
                if temp_char_index >= len(self.idx2char):
                    temp_char = ''
                elif temp_char_index <= 0:
                    temp_char = ''
                elif temp_char_index == self.unknown_idx:
                    temp_char = ''
                else:
                    temp_char = self.idx2char[temp_char_index]
                chars.append(temp_char)
                char_scores.append(res[3])

            text = ''.join(chars)

            texts.append(text)
            scores.append(char_scores)

        return texts, scores
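The per-component majority vote in `tensor2str` can be illustrated standalone with NumPy (toy label map; the real code gets components from `cv2.connectedComponentsWithStats`):

# Each connected component of the thresholded map votes for the class
# that covers most of its pixels; background components are skipped.
import numpy as np

seg_res = np.array([
    [0, 1, 1, 0, 2, 2],
    [0, 1, 3, 0, 2, 2],
])  # toy argmax map: one blob of classes {1, 3}, one blob of {2}

component = (seg_res == 1) | (seg_res == 3)  # pretend cv2 found this blob
values, counts = np.unique(seg_res[component], return_counts=True)
max_cls = values[np.argmax(counts)]
score = counts.max() / counts.sum()
print(max_cls, round(float(score), 2))  # 1 0.75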
@ -0,0 +1,15 @@
from .base_decoder import BaseDecoder
from .crnn_decoder import CRNNDecoder
from .position_attention_decoder import PositionAttentionDecoder
from .robust_scanner_decoder import RobustScannerDecoder
from .sar_decoder import ParallelSARDecoder, SequentialSARDecoder
from .sar_decoder_with_bs import ParallelSARDecoderWithBS
from .sequence_attention_decoder import SequenceAttentionDecoder
from .transformer_decoder import TFDecoder

__all__ = [
    'CRNNDecoder', 'ParallelSARDecoder', 'SequentialSARDecoder',
    'ParallelSARDecoderWithBS', 'TFDecoder', 'BaseDecoder',
    'SequenceAttentionDecoder', 'PositionAttentionDecoder',
    'RobustScannerDecoder'
]
@ -0,0 +1,32 @@
import torch.nn as nn

from mmocr.models.builder import DECODERS


@DECODERS.register_module()
class BaseDecoder(nn.Module):
    """Base decoder class for text recognition."""

    def __init__(self, **kwargs):
        super().__init__()

    def init_weights(self):
        pass

    def forward_train(self, feat, out_enc, targets_dict, img_metas):
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                targets_dict=None,
                img_metas=None,
                train_mode=True):
        self.train_mode = train_mode
        if train_mode:
            return self.forward_train(feat, out_enc, targets_dict, img_metas)

        return self.forward_test(feat, out_enc, img_metas)
@ -0,0 +1,49 @@
import torch.nn as nn
from mmcv.cnn import xavier_init

from mmocr.models.builder import DECODERS
from mmocr.models.textrecog.layers import BidirectionalLSTM
from .base_decoder import BaseDecoder


@DECODERS.register_module()
class CRNNDecoder(BaseDecoder):

    def __init__(self,
                 in_channels=None,
                 num_classes=None,
                 rnn_flag=False,
                 **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.rnn_flag = rnn_flag

        if rnn_flag:
            self.decoder = nn.Sequential(
                BidirectionalLSTM(in_channels, 256, 256),
                BidirectionalLSTM(256, 256, num_classes))
        else:
            self.decoder = nn.Conv2d(
                in_channels, num_classes, kernel_size=1, stride=1)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m)

    def forward_train(self, feat, out_enc, targets_dict, img_metas):
        assert feat.size(2) == 1, 'feature height must be 1'
        if self.rnn_flag:
            x = feat.squeeze(2)  # [N, C, W]
            x = x.permute(2, 0, 1)  # [W, N, C]
            x = self.decoder(x)  # [W, N, C]
            outputs = x.permute(1, 0, 2).contiguous()
        else:
            x = self.decoder(feat)
            x = x.permute(0, 3, 1, 2).contiguous()
            n, w, c, h = x.size()
            outputs = x.view(n, w, c * h)
        return outputs

    def forward_test(self, feat, out_enc, img_metas):
        return self.forward_train(feat, out_enc, None, img_metas)
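A shape sketch for the decoder, assuming it is importable from `mmocr.models.textrecog.decoders`; a CRNN feature map must have height 1:

import torch

from mmocr.models.textrecog.decoders import CRNNDecoder

decoder = CRNNDecoder(in_channels=512, num_classes=37, rnn_flag=True)
feat = torch.randn(4, 512, 1, 40)  # N * C * 1 * W
logits = decoder(
    feat, out_enc=None, targets_dict=None, img_metas=None, train_mode=False)
print(logits.shape)  # torch.Size([4, 40, 37]) -- one step per feature column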
@ -0,0 +1,407 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import mmocr.utils as utils
from mmocr.models.builder import DECODERS
from .base_decoder import BaseDecoder


@DECODERS.register_module()
class ParallelSARDecoder(BaseDecoder):
    """Implementation of Parallel Decoder module in `SAR
    <https://arxiv.org/abs/1811.00751>`_.

    Args:
        num_classes (int): Output class number.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_do_rnn (float): Dropout probability of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_seq_len (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        start_idx (int): Index of start token.
        padding_idx (int): Index of padding token.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
    """

    def __init__(self,
                 num_classes=37,
                 enc_bi_rnn=False,
                 dec_bi_rnn=False,
                 dec_do_rnn=0.0,
                 dec_gru=False,
                 d_model=512,
                 d_enc=512,
                 d_k=64,
                 pred_dropout=0.0,
                 max_seq_len=40,
                 mask=True,
                 start_idx=0,
                 padding_idx=92,
                 pred_concat=False,
                 **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        self.start_idx = start_idx
        self.max_seq_len = max_seq_len
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)
        # 2D attention layer
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2d(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            batch_first=True,
            dropout=dec_do_rnn,
            bidirectional=dec_bi_rnn)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = num_classes - 1  # ignore padding_idx in prediction
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + d_enc
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.size()
        attn_query = attn_query.view(bsz, seq_len, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = attn_weight.permute(0, 1, 3, 4, 2).contiguous()
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = attn_weight.size()
        assert c == 1

        if valid_ratios is not None:
            # cal mask of attention weight
            attn_mask = torch.zeros_like(attn_weight)
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, math.ceil(w * valid_ratio))
                attn_mask[i, :, :, valid_width:, :] = 1
            attn_weight = attn_weight.masked_fill(attn_mask.bool(),
                                                  float('-inf'))

        attn_weight = attn_weight.view(bsz, T, -1)
        attn_weight = F.softmax(attn_weight, dim=-1)
        attn_weight = attn_weight.view(bsz, T, h, w,
                                       c).permute(0, 1, 4, 2, 3).contiguous()

        attn_feat = torch.sum(
            torch.mul(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
        # bsz * (seq_len + 1) * C

        # linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.size(-1)
            holistic_feat = holistic_feat.expand(bsz, seq_len, hf_c)
            y = self.prediction(torch.cat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, targets_dict, img_metas):
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == feat.size(0)

        valid_ratios = [
            img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
        ] if self.mask else None

        targets = targets_dict['padded_targets'].to(feat.device)
        tgt_embedding = self.embedding(targets)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        in_dec = torch.cat((out_enc, tgt_embedding), dim=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)
        # bsz * (seq_len + 1) * num_classes

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == feat.size(0)

        valid_ratios = [
            img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
        ] if self.mask else None

        seq_len = self.max_seq_len

        bsz = feat.size(0)
        start_token = torch.full((bsz, ),
                                 self.start_idx,
                                 device=feat.device,
                                 dtype=torch.long)
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = torch.cat((out_enc, start_token), dim=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            _, max_idx = torch.max(char_output, dim=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = torch.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs


@DECODERS.register_module()
class SequentialSARDecoder(BaseDecoder):
    """Implementation of Sequential Decoder module in `SAR
    <https://arxiv.org/abs/1811.00751>`_.

    Args:
        num_classes (int): Number of output classes.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_k (int): Dim of conv layers in attention module.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        pred_dropout (float): Dropout probability of prediction layer.
        max_seq_len (int): Maximum sequence length during decoding.
        mask (bool): If True, mask padding in feature map.
        start_idx (int): Index of start token.
        padding_idx (int): Index of padding token.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
    """

    def __init__(self,
                 num_classes=37,
                 enc_bi_rnn=False,
                 dec_bi_rnn=False,
                 dec_gru=False,
                 d_k=64,
                 d_model=512,
                 d_enc=512,
                 pred_dropout=0.0,
                 mask=True,
                 max_seq_len=40,
                 start_idx=0,
                 padding_idx=92,
                 pred_concat=False,
                 **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        self.start_idx = start_idx
        self.dec_gru = dec_gru
        self.max_seq_len = max_seq_len
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)
        # 2D attention layer
        self.conv1x1_1 = nn.Conv2d(
            decoder_rnn_out_size, d_k, kernel_size=1, stride=1)
        self.conv3x3_1 = nn.Conv2d(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Conv2d(d_k, 1, kernel_size=1, stride=1)

        # Decoder rnn layer
        if dec_gru:
            self.rnn_decoder_layer1 = nn.GRUCell(encoder_rnn_out_size,
                                                 encoder_rnn_out_size)
            self.rnn_decoder_layer2 = nn.GRUCell(encoder_rnn_out_size,
                                                 encoder_rnn_out_size)
        else:
            self.rnn_decoder_layer1 = nn.LSTMCell(encoder_rnn_out_size,
                                                  encoder_rnn_out_size)
            self.rnn_decoder_layer2 = nn.LSTMCell(encoder_rnn_out_size,
                                                  encoder_rnn_out_size)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes, encoder_rnn_out_size, padding_idx=padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_class = num_classes - 1  # ignore padding index
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + d_enc
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_class)
|
||||
y_prev,
|
||||
feat,
|
||||
holistic_feat,
|
||||
hx1,
|
||||
cx1,
|
||||
hx2,
|
||||
cx2,
|
||||
valid_ratios=None):
|
||||
_, _, h_feat, w_feat = feat.size()
|
||||
if self.dec_gru:
|
||||
hx1 = cx1 = self.rnn_decoder_layer1(y_prev, hx1)
|
||||
hx2 = cx2 = self.rnn_decoder_layer2(hx1, hx2)
|
||||
else:
|
||||
hx1, cx1 = self.rnn_decoder_layer1(y_prev, (hx1, cx1))
|
||||
hx2, cx2 = self.rnn_decoder_layer2(hx1, (hx2, cx2))
|
||||
|
||||
tile_hx2 = hx2.view(hx2.size(0), hx2.size(1), 1, 1)
|
||||
attn_query = self.conv1x1_1(tile_hx2) # bsz * attn_size * 1 * 1
|
||||
attn_query = attn_query.expand(-1, -1, h_feat, w_feat)
|
||||
attn_key = self.conv3x3_1(feat)
|
||||
attn_weight = torch.tanh(torch.add(attn_key, attn_query, alpha=1))
|
||||
attn_weight = self.conv1x1_2(attn_weight)
|
||||
bsz, c, h, w = attn_weight.size()
|
||||
assert c == 1
|
||||
|
||||
if valid_ratios is not None:
|
||||
# cal mask of attention weight
|
||||
attn_mask = torch.zeros_like(attn_weight)
|
||||
for i, valid_ratio in enumerate(valid_ratios):
|
||||
valid_width = min(w, math.ceil(w * valid_ratio))
|
||||
attn_mask[i, :, :, valid_width:] = 1
|
||||
attn_weight = attn_weight.masked_fill(attn_mask.bool(),
|
||||
float('-inf'))
|
||||
|
||||
attn_weight = F.softmax(attn_weight.view(bsz, -1), dim=-1)
|
||||
attn_weight = attn_weight.view(bsz, c, h, w)
|
||||
|
||||
attn_feat = torch.sum(
|
||||
torch.mul(feat, attn_weight), (2, 3), keepdim=False) # n * c
|
||||
|
||||
# linear transformation
|
||||
if self.pred_concat:
|
||||
y = self.prediction(torch.cat((hx2, attn_feat, holistic_feat), 1))
|
||||
else:
|
||||
y = self.prediction(attn_feat)
|
||||
|
||||
return y, hx1, hx1, hx2, hx2
|
||||
|
||||
def forward_train(self, feat, out_enc, targets_dict, img_metas=None):
|
||||
assert utils.is_type_list(img_metas, dict)
|
||||
assert len(img_metas) == feat.size(0)
|
||||
|
||||
valid_ratios = [
|
||||
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
|
||||
] if self.mask else None
|
||||
|
||||
if self.train_mode:
|
||||
targets = targets_dict['padded_targets'].to(feat.device)
|
||||
tgt_embedding = self.embedding(targets)
|
||||
|
||||
outputs = []
|
||||
start_token = torch.full((feat.size(0), ),
|
||||
self.start_idx,
|
||||
device=feat.device,
|
||||
dtype=torch.long)
|
||||
start_token = self.embedding(start_token)
|
||||
for i in range(-1, self.max_seq_len):
|
||||
if i == -1:
|
||||
if self.dec_gru:
|
||||
hx1 = cx1 = self.rnn_decoder_layer1(out_enc)
|
||||
hx2 = cx2 = self.rnn_decoder_layer2(hx1)
|
||||
else:
|
||||
hx1, cx1 = self.rnn_decoder_layer1(out_enc)
|
||||
hx2, cx2 = self.rnn_decoder_layer2(hx1)
|
||||
if not self.train_mode:
|
||||
y_prev = start_token
|
||||
else:
|
||||
if self.train_mode:
|
||||
y_prev = tgt_embedding[:, i, :]
|
||||
y, hx1, cx1, hx2, cx2 = self._2d_attention(
|
||||
y_prev,
|
||||
feat,
|
||||
out_enc,
|
||||
hx1,
|
||||
cx1,
|
||||
hx2,
|
||||
cx2,
|
||||
valid_ratios=valid_ratios)
|
||||
if self.train_mode:
|
||||
y = self.pred_dropout(y)
|
||||
else:
|
||||
y = F.softmax(y, -1)
|
||||
_, max_idx = torch.max(y, dim=1, keepdim=False)
|
||||
char_embedding = self.embedding(max_idx)
|
||||
y_prev = char_embedding
|
||||
outputs.append(y)
|
||||
|
||||
outputs = torch.stack(outputs, 1)
|
||||
|
||||
return outputs
|
||||
|
||||
def forward_test(self, feat, out_enc, img_metas):
|
||||
assert utils.is_type_list(img_metas, dict)
|
||||
assert len(img_metas) == feat.size(0)
|
||||
|
||||
return self.forward_train(feat, out_enc, None, img_metas)
|
|
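An inference sketch for the parallel decoder, with illustrative class indices (a real config wires `num_classes`, `start_idx` and `padding_idx` from the label convertor):

import torch

from mmocr.models.textrecog.decoders import ParallelSARDecoder

decoder = ParallelSARDecoder(
    num_classes=37, d_model=512, d_enc=512, start_idx=36, padding_idx=36,
    max_seq_len=30)
feat = torch.randn(2, 512, 8, 32)  # backbone feature map
out_enc = torch.randn(2, 512)      # holistic feature from SAREncoder
img_metas = [dict(valid_ratio=1.0), dict(valid_ratio=0.8)]
probs = decoder(feat, out_enc, img_metas=img_metas, train_mode=False)
print(probs.shape)  # torch.Size([2, 30, 36]) -- num_classes - 1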
@ -0,0 +1,148 @@
from queue import PriorityQueue

import torch
import torch.nn.functional as F

import mmocr.utils as utils
from mmocr.models.builder import DECODERS
from . import ParallelSARDecoder


class DecodeNode:
    """Node class to save decoded char indices and scores.

    Args:
        indexes (list[int]): Char indices decoded so far.
        scores (list[float]): Scores of the chars decoded so far.
    """

    def __init__(self, indexes=[1], scores=[0.9]):
        assert utils.is_type_list(indexes, int)
        assert utils.is_type_list(scores, float)
        assert utils.equal_len(indexes, scores)

        self.indexes = indexes
        self.scores = scores

    def eval(self):
        """Calculate accumulated score."""
        accu_score = sum(self.scores)
        return accu_score


@DECODERS.register_module()
class ParallelSARDecoderWithBS(ParallelSARDecoder):
    """Parallel Decoder module with beam-search in SAR.

    Args:
        beam_width (int): Width for beam search.
    """

    def __init__(self,
                 beam_width=5,
                 num_classes=37,
                 enc_bi_rnn=False,
                 dec_bi_rnn=False,
                 dec_do_rnn=0,
                 dec_gru=False,
                 d_model=512,
                 d_enc=512,
                 d_k=64,
                 pred_dropout=0.0,
                 max_seq_len=40,
                 mask=True,
                 start_idx=0,
                 padding_idx=0,
                 pred_concat=False,
                 **kwargs):
        super().__init__(num_classes, enc_bi_rnn, dec_bi_rnn, dec_do_rnn,
                         dec_gru, d_model, d_enc, d_k, pred_dropout,
                         max_seq_len, mask, start_idx, padding_idx,
                         pred_concat)
        assert isinstance(beam_width, int)
        assert beam_width > 0

        self.beam_width = beam_width

    def forward_test(self, feat, out_enc, img_metas):
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == feat.size(0)

        valid_ratios = [
            img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
        ] if self.mask else None

        seq_len = self.max_seq_len
        bsz = feat.size(0)
        assert bsz == 1, 'batch size must be 1 for beam search.'

        start_token = torch.full((bsz, ),
                                 self.start_idx,
                                 device=feat.device,
                                 dtype=torch.long)
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        start_token = start_token.unsqueeze(1).expand(-1, seq_len, -1)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = torch.cat((out_enc, start_token), dim=1)
        # bsz * (seq_len + 1) * emb_dim

        # Initialize beam-search queue
        q = PriorityQueue()
        init_node = DecodeNode([self.start_idx], [0.0])
        q.put((-init_node.eval(), init_node))

        for i in range(1, seq_len + 1):
            next_nodes = []
            beam_width = self.beam_width if i > 1 else 1
            for _ in range(beam_width):
                _, node = q.get()

                input_seq = torch.clone(decoder_input)  # bsz * T * emb_dim
                # fill previous input tokens (step 1...i) in input_seq
                for t, index in enumerate(node.indexes):
                    input_token = torch.full((bsz, ),
                                             index,
                                             device=input_seq.device,
                                             dtype=torch.long)
                    input_token = self.embedding(input_token)  # bsz * emb_dim
                    input_seq[:, t + 1, :] = input_token

                output_seq = self._2d_attention(
                    input_seq, feat, out_enc, valid_ratios=valid_ratios)

                output_char = output_seq[:, i, :]  # bsz * num_classes
                output_char = F.softmax(output_char, -1)
                topk_value, topk_idx = output_char.topk(self.beam_width, dim=1)
                topk_value, topk_idx = topk_value.squeeze(0), topk_idx.squeeze(
                    0)

                for k in range(self.beam_width):
                    kth_score = topk_value[k].item()
                    kth_idx = topk_idx[k].item()
                    next_node = DecodeNode(node.indexes + [kth_idx],
                                           node.scores + [kth_score])
                    delta = k * 1e-6
                    next_nodes.append(
                        (-node.eval() - kth_score - delta, next_node))
                    # Use minus since the priority queue sorts
                    # in ascending order

            while not q.empty():
                q.get()

            # Put all candidates to queue
            for next_node in next_nodes:
                q.put(next_node)

        best_node = q.get()
        num_classes = self.num_classes - 1  # ignore padding index
        outputs = torch.zeros(bsz, seq_len, num_classes)
        for i in range(seq_len):
            idx = best_node[1].indexes[i + 1]
            outputs[0, i, idx] = best_node[1].scores[i + 1]

        return outputs
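The negation before `q.put` exists because `PriorityQueue` pops the smallest item first; the `k * 1e-6` delta breaks exact score ties so two `DecodeNode` objects (which define no ordering) are never compared directly. A standalone demonstration:

from queue import PriorityQueue

q = PriorityQueue()
for score, tag in [(0.9, 'a'), (0.95, 'b'), (0.1, 'c')]:
    q.put((-score, tag))
print(q.get())  # (-0.95, 'b') -- highest accumulated score comes out first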
@ -0,0 +1,99 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmocr.models.builder import DECODERS
from mmocr.models.textrecog.layers import (DecoderLayer, PositionalEncoding,
                                           get_pad_mask, get_subsequent_mask)
from .base_decoder import BaseDecoder


@DECODERS.register_module()
class TFDecoder(BaseDecoder):
    """Transformer Decoder block with self attention mechanism."""

    def __init__(self,
                 n_layers=6,
                 d_embedding=512,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 d_model=512,
                 d_inner=256,
                 n_position=200,
                 dropout=0.1,
                 num_classes=93,
                 max_seq_len=40,
                 start_idx=1,
                 padding_idx=92,
                 **kwargs):
        super().__init__()

        self.padding_idx = padding_idx
        self.start_idx = start_idx
        self.max_seq_len = max_seq_len

        self.trg_word_emb = nn.Embedding(
            num_classes, d_embedding, padding_idx=padding_idx)

        self.position_enc = PositionalEncoding(
            d_embedding, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)

        self.layer_stack = nn.ModuleList([
            DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        pred_num_class = num_classes - 1  # ignore padding_idx
        self.classifier = nn.Linear(d_model, pred_num_class)

    def _attention(self, trg_seq, src, src_mask=None):
        trg_embedding = self.trg_word_emb(trg_seq)
        trg_pos_encoded = self.position_enc(trg_embedding)
        tgt = self.dropout(trg_pos_encoded)

        trg_mask = get_pad_mask(
            trg_seq, pad_idx=self.padding_idx) & get_subsequent_mask(trg_seq)
        output = tgt
        for dec_layer in self.layer_stack:
            output = dec_layer(
                output,
                src,
                slf_attn_mask=trg_mask,
                dec_enc_attn_mask=src_mask)
        output = self.layer_norm(output)

        return output

    def forward_train(self, feat, out_enc, targets_dict, img_metas):
        targets = targets_dict['padded_targets'].to(out_enc.device)
        attn_output = self._attention(targets, out_enc, src_mask=None)
        outputs = self.classifier(attn_output)
        return outputs

    def forward_test(self, feat, out_enc, img_metas):
        bsz = out_enc.size(0)
        init_target_seq = torch.full((bsz, self.max_seq_len + 1),
                                     self.padding_idx,
                                     device=out_enc.device,
                                     dtype=torch.long)
        # bsz * seq_len
        init_target_seq[:, 0] = self.start_idx

        outputs = []
        for step in range(0, self.max_seq_len):
            decoder_output = self._attention(
                init_target_seq, out_enc, src_mask=None)
            # bsz * seq_len * 512
            step_result = F.softmax(
                self.classifier(decoder_output[:, step, :]), dim=-1)
            # bsz * num_classes
            outputs.append(step_result)
            _, step_max_index = torch.max(step_result, dim=-1)
            init_target_seq[:, step + 1] = step_max_index

        outputs = torch.stack(outputs, dim=1)

        return outputs
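A greedy-decoding shape sketch, assuming `TFDecoder` is importable from `mmocr.models.textrecog.decoders` and feeding a dummy encoder memory in place of `TFEncoder` output:

import torch

from mmocr.models.textrecog.decoders import TFDecoder

decoder = TFDecoder(n_layers=2, num_classes=93, max_seq_len=10)
decoder.eval()  # disable dropout for the sketch
out_enc = torch.randn(2, 160, 512)  # N * T_src * d_model
probs = decoder(None, out_enc, img_metas=None, train_mode=False)
print(probs.shape)  # torch.Size([2, 10, 92]) -- num_classes - 1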
@ -0,0 +1,6 @@
from .base_encoder import BaseEncoder
from .channel_reduction_encoder import ChannelReductionEncoder
from .sar_encoder import SAREncoder
from .transformer_encoder import TFEncoder

__all__ = ['SAREncoder', 'TFEncoder', 'BaseEncoder', 'ChannelReductionEncoder']
@ -0,0 +1,14 @@
import torch.nn as nn

from mmocr.models.builder import ENCODERS


@ENCODERS.register_module()
class BaseEncoder(nn.Module):
    """Base Encoder class for text recognition."""

    def init_weights(self):
        pass

    def forward(self, feat, **kwargs):
        return feat
@ -0,0 +1,102 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import uniform_init, xavier_init

import mmocr.utils as utils
from mmocr.models.builder import ENCODERS
from .base_encoder import BaseEncoder


@ENCODERS.register_module()
class SAREncoder(BaseEncoder):
    """Implementation of encoder module in `SAR
    <https://arxiv.org/abs/1811.00751>`_.

    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_do_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_do_rnn=0.0,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_do_rnn, (int, float))
        assert 0 <= enc_do_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_do_rnn = enc_do_rnn
        self.mask = mask

        # LSTM Encoder
        kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            batch_first=True,
            dropout=enc_do_rnn,
            bidirectional=enc_bi_rnn)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**kwargs)

        # global feature transformation
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def init_weights(self):
        # initialize weight and bias
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                uniform_init(m)

    def forward(self, feat, img_metas=None):
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == feat.size(0)

        valid_ratios = [
            img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
        ] if self.mask else None

        h_feat = feat.size(2)
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = feat_v.permute(0, 2, 1).contiguous()  # bsz * W * C

        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            valid_hf = []
            T = holistic_feat.size(1)
            for i, valid_ratio in enumerate(valid_ratios):
                valid_step = min(T, math.ceil(T * valid_ratio)) - 1
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = torch.stack(valid_hf, dim=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C

        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat
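A sketch of the encoder's valid-step selection, assuming `SAREncoder` is importable from `mmocr.models.textrecog.encoders`: with `valid_ratio=0.5` and T = 32 RNN steps, the holistic feature is read at step ceil(32 * 0.5) - 1 = 15 rather than at the padded final step:

import torch

from mmocr.models.textrecog.encoders import SAREncoder

encoder = SAREncoder(d_model=512, d_enc=512)
feat = torch.randn(2, 512, 8, 32)  # N * C * H * W
img_metas = [dict(valid_ratio=1.0), dict(valid_ratio=0.5)]
holistic = encoder(feat, img_metas)
print(holistic.shape)  # torch.Size([2, 512])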
@ -0,0 +1,16 @@
from mmocr.models.builder import ENCODERS
from .base_encoder import BaseEncoder


@ENCODERS.register_module()
class TFEncoder(BaseEncoder):
    """Encode 2d feature map to 1d sequence."""

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, feat, img_metas=None):
        n, c, _, _ = feat.size()
        enc_output = feat.view(n, c, -1).transpose(2, 1).contiguous()

        return enc_output
@ -0,0 +1,3 @@
from .seg_head import SegHead

__all__ = ['SegHead']
@ -0,0 +1,50 @@
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from torch import nn

from mmdet.models.builder import HEADS


@HEADS.register_module()
class SegHead(nn.Module):
    """Head for segmentation based text recognition.

    Args:
        in_channels (int): Number of input channels.
        num_classes (int): Number of output classes.
        upsample_param (dict | None): Config dict for interpolation layer.
            Default: `dict(scale_factor=1.0, mode='nearest')`
    """

    def __init__(self, in_channels=128, num_classes=37, upsample_param=None):
        super().__init__()
        assert isinstance(num_classes, int)
        assert num_classes > 0
        assert upsample_param is None or isinstance(upsample_param, dict)

        self.upsample_param = upsample_param

        self.seg_conv = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=1,
            padding=1,
            norm_cfg=dict(type='BN'))

        # prediction
        self.pred_conv = nn.Conv2d(
            in_channels, num_classes, kernel_size=1, stride=1, padding=0)

    def init_weights(self):
        pass

    def forward(self, out_neck):
        seg_map = self.seg_conv(out_neck[-1])
        seg_map = self.pred_conv(seg_map)

        if self.upsample_param is not None:
            seg_map = F.interpolate(seg_map, **self.upsample_param)

        return seg_map
@ -0,0 +1,15 @@
from .conv_layer import BasicBlock, Bottleneck
from .dot_product_attention_layer import DotProductAttentionLayer
from .lstm_layer import BidirectionalLSTM
from .position_aware_layer import PositionAwareLayer
from .robust_scanner_fusion_layer import RobustScannerFusionLayer
from .transformer_layer import (DecoderLayer, MultiHeadAttention,
                                PositionalEncoding, PositionwiseFeedForward,
                                get_pad_mask, get_subsequent_mask)

__all__ = [
    'BidirectionalLSTM', 'MultiHeadAttention', 'PositionalEncoding',
    'PositionwiseFeedForward', 'BasicBlock', 'Bottleneck',
    'RobustScannerFusionLayer', 'DotProductAttentionLayer',
    'PositionAwareLayer', 'DecoderLayer', 'get_pad_mask', 'get_subsequent_mask'
]
@ -0,0 +1,93 @@
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=False):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        if downsample:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes, planes * self.expansion, 1, stride, bias=False),
                nn.BatchNorm2d(planes * self.expansion),
            )
        else:
            self.downsample = nn.Sequential()
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=False):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        if downsample:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes, planes * self.expansion, 1, stride, bias=False),
                nn.BatchNorm2d(planes * self.expansion),
            )
        else:
            self.downsample = nn.Sequential()

    def forward(self, x):
        residual = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += residual
        out = self.relu(out)

        return out
@ -0,0 +1,20 @@
import torch.nn as nn


class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super().__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output
@ -0,0 +1,172 @@
"""This code is from https://github.com/jadore801120/attention-is-all-you-need-
pytorch."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class DecoderLayer(nn.Module):
    """Compose with three layers:

    1. MultiHeadSelfAttn
    2. MultiHeadEncoderDecoderAttn
    3. PositionwiseFeedForward
    """

    def __init__(self,
                 d_model=512,
                 d_inner=256,
                 n_head=8,
                 d_k=64,
                 d_v=64,
                 dropout=0.1):
        super().__init__()
        self.slf_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, dropout=dropout)

    def forward(self,
                dec_input,
                enc_output,
                slf_attn_mask=None,
                dec_enc_attn_mask=None):
        dec_output = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask)
        dec_output = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
        dec_output = self.pos_ffn(dec_output)

        return dec_output


class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention."""

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        attn = torch.matmul(q / self.temperature, k.transpose(1, 2))

        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn


class MultiHeadAttention(nn.Module):
    """Multi-Head Attention module."""

    def __init__(self, n_head=8, d_model=512, d_k=64, d_v=64, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs_list = nn.ModuleList(
            [nn.Linear(d_model, d_k, bias=False) for _ in range(n_head)])
        self.w_ks_list = nn.ModuleList(
            [nn.Linear(d_model, d_k, bias=False) for _ in range(n_head)])
        self.w_vs_list = nn.ModuleList(
            [nn.Linear(d_model, d_v, bias=False) for _ in range(n_head)])
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k**0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        residual = q
        q = self.layer_norm(q)

        attention_q_list = []
        for head_index in range(self.n_head):
            q_each = self.w_qs_list[head_index](q)  # bsz * seq_len * d_k
            k_each = self.w_ks_list[head_index](k)  # bsz * seq_len * d_k
            v_each = self.w_vs_list[head_index](v)  # bsz * seq_len * d_v
            attention_q_each, _ = self.attention(
                q_each, k_each, v_each, mask=mask)
            attention_q_list.append(attention_q_each)

        q = torch.cat(attention_q_list, dim=-1)
        q = self.dropout(self.fc(q))
        q += residual

        return q


class PositionwiseFeedForward(nn.Module):
    """A two-feed-forward-layer module."""

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)  # position-wise
        self.w_2 = nn.Linear(d_hid, d_in)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.layer_norm(x)

        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual

        return x


class PositionalEncoding(nn.Module):

    def __init__(self, d_hid=512, n_position=200):
        super().__init__()

        # Not a parameter
        self.register_buffer(
            'position_table',
            self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table."""
        denominator = torch.Tensor([
            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ])
        denominator = denominator.view(1, -1)
        pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
        sinusoid_table = pos_tensor * denominator
        sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
        sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])

        return sinusoid_table.unsqueeze(0)

    def forward(self, x):
        self.device = x.device
        return x + self.position_table[:, :x.size(1)].clone().detach()


def get_pad_mask(seq, pad_idx):
    return (seq != pad_idx).unsqueeze(-2)


def get_subsequent_mask(seq):
    """For masking out the subsequent info."""
    len_s = seq.size(1)
    subsequent_mask = 1 - torch.triu(
        torch.ones((len_s, len_s), device=seq.device), diagonal=1)
    subsequent_mask = subsequent_mask.unsqueeze(0).bool()
    return subsequent_mask
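A standalone demonstration of the two masks combined in `TFDecoder._attention`: pad positions are hidden everywhere, and step i may only attend to steps <= i (assumes the helpers are importable from `mmocr.models.textrecog.layers`):

import torch

from mmocr.models.textrecog.layers import get_pad_mask, get_subsequent_mask

seq = torch.tensor([[1, 5, 7, 92, 92]])  # 92 = padding_idx
pad_mask = get_pad_mask(seq, pad_idx=92)  # (1, 1, 5)
sub_mask = get_subsequent_mask(seq)       # (1, 5, 5)
full_mask = pad_mask & sub_mask           # broadcast to (1, 5, 5)
print(full_mask[0].int())
# row i: which positions step i may attend to (1 = visible)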
@ -0,0 +1,5 @@
from .ce_loss import CELoss, SARLoss, TFLoss
from .ctc_loss import CTCLoss
from .seg_loss import CAFCNLoss, SegLoss

__all__ = ['CELoss', 'SARLoss', 'CTCLoss', 'TFLoss', 'SegLoss', 'CAFCNLoss']
@ -0,0 +1,94 @@
import torch.nn as nn

from mmdet.models.builder import LOSSES


@LOSSES.register_module()
class CELoss(nn.Module):
    """Implementation of loss module for encoder-decoder based text
    recognition method with CrossEntropy loss.

    Args:
        ignore_index (int): Specifies a target value that is
            ignored and does not contribute to the input gradient.
        reduction (str): Specifies the reduction to apply to the output,
            should be one of the following: ('none', 'mean', 'sum').
    """

    def __init__(self, ignore_index=-1, reduction='none'):
        super().__init__()
        assert isinstance(ignore_index, int)
        assert isinstance(reduction, str)
        assert reduction in ['none', 'mean', 'sum']

        self.loss_ce = nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction)

    def format(self, outputs, targets_dict):
        targets = targets_dict['padded_targets']

        return outputs.permute(0, 2, 1).contiguous(), targets

    def forward(self, outputs, targets_dict):
        outputs, targets = self.format(outputs, targets_dict)

        loss_ce = self.loss_ce(outputs, targets.to(outputs.device))
        losses = dict(loss_ce=loss_ce)

        return losses


@LOSSES.register_module()
class SARLoss(CELoss):
    """Implementation of loss module in `SAR
    <https://arxiv.org/abs/1811.00751>`_.

    Args:
        ignore_index (int): Specifies a target value that is
            ignored and does not contribute to the input gradient.
        reduction (str): Specifies the reduction to apply to the output,
            should be one of the following: ('none', 'mean', 'sum').
    """

    def __init__(self, ignore_index=0, reduction='mean', **kwargs):
        super().__init__(ignore_index, reduction)

    def format(self, outputs, targets_dict):
        targets = targets_dict['padded_targets']
        # targets[0, :], [start_idx, idx1, idx2, ..., end_idx, pad_idx...]
        # outputs[0, :, 0], [idx1, idx2, ..., end_idx, ...]

        # ignore first index of target in loss calculation
        targets = targets[:, 1:].contiguous()
        # ignore last index of outputs to be in same seq_len with targets
        outputs = outputs[:, :-1, :].permute(0, 2, 1).contiguous()

        return outputs, targets


@LOSSES.register_module()
class TFLoss(CELoss):
    """Implementation of loss module for transformer."""

    def __init__(self,
                 ignore_index=-1,
                 reduction='none',
                 flatten=True,
                 **kwargs):
        super().__init__(ignore_index, reduction)
        assert isinstance(flatten, bool)

        self.flatten = flatten

    def format(self, outputs, targets_dict):
        outputs = outputs[:, :-1, :].contiguous()
        targets = targets_dict['padded_targets']
        targets = targets[:, 1:].contiguous()
        if self.flatten:
            outputs = outputs.view(-1, outputs.size(-1))
            targets = targets.view(-1)
        else:
            outputs = outputs.permute(0, 2, 1).contiguous()

        return outputs, targets
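The shift performed by `SARLoss.format`, illustrated on toy tensors (indices are illustrative): the start token is dropped from the targets and the last prediction step is dropped from the outputs so both end up with the same sequence length:

import torch

targets = torch.tensor([[36, 7, 4, 11, 11, 14, 0, 0]])  # start, 'hello', pad
outputs = torch.randn(1, 8, 36)                          # N * seq_len * C

shifted_targets = targets[:, 1:]      # (1, 7): drop the start token
shifted_outputs = outputs[:, :-1, :]  # (1, 7, 36): drop the last step
print(shifted_targets.shape, shifted_outputs.shape)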
@ -0,0 +1,63 @@
import torch
import torch.nn as nn

from mmdet.models.builder import LOSSES


@LOSSES.register_module()
class CTCLoss(nn.Module):
    """Implementation of loss module for CTC-loss based text recognition.

    Args:
        flatten (bool): If True, use flattened targets, else padded targets.
        blank (int): Blank label. Default 0.
        reduction (str): Specifies the reduction to apply to the output,
            should be one of the following: ('none', 'mean', 'sum').
        zero_infinity (bool): Whether to zero infinite losses and
            the associated gradients. Default: False.
            Infinite losses mainly occur when the inputs
            are too short to be aligned to the targets.
    """

    def __init__(self,
                 flatten=True,
                 blank=0,
                 reduction='mean',
                 zero_infinity=False,
                 **kwargs):
        super().__init__()
        assert isinstance(flatten, bool)
        assert isinstance(blank, int)
        assert isinstance(reduction, str)
        assert isinstance(zero_infinity, bool)

        self.flatten = flatten
        self.blank = blank
        self.ctc_loss = nn.CTCLoss(
            blank=blank, reduction=reduction, zero_infinity=zero_infinity)

    def forward(self, outputs, targets_dict):
        outputs = torch.log_softmax(outputs, dim=2)
        bsz, seq_len = outputs.size(0), outputs.size(1)
        input_lengths = torch.full(
            size=(bsz, ), fill_value=seq_len, dtype=torch.long)
        outputs_for_loss = outputs.permute(1, 0, 2).contiguous()  # T * N * C

        if self.flatten:
            targets = targets_dict['flatten_targets']
        else:
            targets = torch.full(
                size=(bsz, seq_len), fill_value=self.blank, dtype=torch.long)
            for idx, tensor in enumerate(targets_dict['targets']):
                valid_len = min(tensor.size(0), seq_len)
                targets[idx, :valid_len] = tensor[:valid_len]

        target_lengths = targets_dict['target_lengths']

        loss_ctc = self.ctc_loss(outputs_for_loss, targets, input_lengths,
                                 target_lengths)

        losses = dict(loss_ctc=loss_ctc)

        return losses
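An end-to-end sketch tying `CTCConvertor` targets to `CTCLoss`, assuming both are importable from their `mmocr.models.textrecog` subpackages; logits are random, so only the shapes matter:

import torch

from mmocr.models.textrecog.convertors import CTCConvertor
from mmocr.models.textrecog.losses import CTCLoss

convertor = CTCConvertor(dict_type='DICT36', lower=True)
targets_dict = convertor.str2tensor(['hello', 'world'])

criterion = CTCLoss(flatten=True)
logits = torch.randn(2, 20, len(convertor.idx2char))  # N * T * C
losses = criterion(logits, targets_dict)
print(losses['loss_ctc'])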
@ -0,0 +1,176 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.models.builder import LOSSES
from mmocr.models.common.losses import DiceLoss


@LOSSES.register_module()
class SegLoss(nn.Module):
    """Implementation of loss module for segmentation based text recognition
    method.

    Args:
        seg_downsample_ratio (float): Downsample ratio of
            segmentation map.
        seg_with_loss_weight (bool): If True, set weight for
            segmentation loss.
        ignore_index (int): Specifies a target value that is ignored
            and does not contribute to the input gradient.
    """

    def __init__(self,
                 seg_downsample_ratio=0.5,
                 seg_with_loss_weight=True,
                 ignore_index=255,
                 **kwargs):
        super().__init__()

        assert isinstance(seg_downsample_ratio, (int, float))
        assert 0 < seg_downsample_ratio <= 1
        assert isinstance(ignore_index, int)

        self.seg_downsample_ratio = seg_downsample_ratio
        self.seg_with_loss_weight = seg_with_loss_weight
        self.ignore_index = ignore_index

    def seg_loss(self, out_head, gt_kernels):
        seg_map = out_head  # bsz * num_classes * H/2 * W/2
        seg_target = [
            item[1].rescale(self.seg_downsample_ratio).to_tensor(
                torch.long, seg_map.device) for item in gt_kernels
        ]
        seg_target = torch.stack(seg_target).squeeze(1)

        loss_weight = None
        if self.seg_with_loss_weight:
            N = torch.sum(seg_target != self.ignore_index)
            N_neg = torch.sum(seg_target == 0)
            weight_val = 1.0 * N_neg / (N - N_neg)
            loss_weight = torch.ones(seg_map.size(1), device=seg_map.device)
            loss_weight[1:] = weight_val

        loss_seg = F.cross_entropy(
            seg_map,
            seg_target,
            weight=loss_weight,
            ignore_index=self.ignore_index)

        return loss_seg

    def forward(self, out_neck, out_head, gt_kernels):

        losses = {}

        loss_seg = self.seg_loss(out_head, gt_kernels)

        losses['loss_seg'] = loss_seg

        return losses


@LOSSES.register_module()
class CAFCNLoss(SegLoss):
    """Implementation of loss module in `CA-FCN
    <https://arxiv.org/pdf/1809.06508.pdf>`_.

    Args:
        alpha (float): Weight ratio of attention loss.
        attn_s2_downsample_ratio (float): Downsample ratio
            of attention map from output stage 2.
        attn_s3_downsample_ratio (float): Downsample ratio
            of attention map from output stage 3.
        seg_downsample_ratio (float): Downsample ratio of
            segmentation map.
        attn_with_dice_loss (bool): If True, use dice_loss for attention,
            else BCELoss.
        with_attn (bool): If True, include attention loss, else
            segmentation loss only.
        seg_with_loss_weight (bool): If True, set weight for
            segmentation loss.
        ignore_index (int): Specifies a target value that is ignored
            and does not contribute to the input gradient.
    """

    def __init__(self,
                 alpha=1.0,
                 attn_s2_downsample_ratio=0.25,
                 attn_s3_downsample_ratio=0.125,
                 seg_downsample_ratio=0.5,
                 attn_with_dice_loss=False,
                 with_attn=True,
                 seg_with_loss_weight=True,
                 ignore_index=255):
        super().__init__(seg_downsample_ratio, seg_with_loss_weight,
                         ignore_index)
        assert isinstance(alpha, (int, float))
        assert isinstance(attn_s2_downsample_ratio, (int, float))
        assert isinstance(attn_s3_downsample_ratio, (int, float))
        assert 0 < attn_s2_downsample_ratio <= 1
        assert 0 < attn_s3_downsample_ratio <= 1

        self.alpha = alpha
        self.attn_s2_downsample_ratio = attn_s2_downsample_ratio
        self.attn_s3_downsample_ratio = attn_s3_downsample_ratio
        self.with_attn = with_attn
        self.attn_with_dice_loss = attn_with_dice_loss

        # attention loss
        if with_attn:
            if attn_with_dice_loss:
                self.criterion_attn = DiceLoss()
            else:
                self.criterion_attn = nn.BCELoss(reduction='none')

    def attn_loss(self, out_neck, gt_kernels):
        attn_map_s2 = out_neck[0]  # bsz * 2 * H/4 * W/4

        mask_s2 = torch.stack([
            item[2].rescale(self.attn_s2_downsample_ratio).to_tensor(
                torch.float, attn_map_s2.device) for item in gt_kernels
        ])

        attn_target_s2 = torch.stack([
            item[0].rescale(self.attn_s2_downsample_ratio).to_tensor(
                torch.float, attn_map_s2.device) for item in gt_kernels
        ])

        mask_s3 = torch.stack([
            item[2].rescale(self.attn_s3_downsample_ratio).to_tensor(
                torch.float, attn_map_s2.device) for item in gt_kernels
        ])

        attn_target_s3 = torch.stack([
            item[0].rescale(self.attn_s3_downsample_ratio).to_tensor(
                torch.float, attn_map_s2.device) for item in gt_kernels
        ])

        targets = [
            attn_target_s2, attn_target_s3, attn_target_s3, attn_target_s3
        ]
        masks = [mask_s2, mask_s3, mask_s3, mask_s3]

        loss_attn = 0.
        for i in range(len(out_neck) - 1):
            pred = out_neck[i]
            if self.attn_with_dice_loss:
                loss_attn += self.criterion_attn(pred, targets[i], masks[i])
            else:
                loss_attn += torch.sum(
                    self.criterion_attn(pred, targets[i]) *
                    masks[i]) / torch.sum(masks[i])

        return loss_attn

    def forward(self, out_neck, out_head, gt_kernels):

        losses = super().forward(out_neck, out_head, gt_kernels)

        if self.with_attn:
            loss_attn = self.attn_loss(out_neck, gt_kernels)
            losses['loss_attn'] = loss_attn

        return losses
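As a usage sketch (not part of the diff), the loss can be built through the mmdet registry; the config keys mirror the constructor arguments above, and the values are only examples:

# Hedged sketch: building CAFCNLoss via the registry.
from mmdet.models.builder import build_loss

loss_cfg = dict(
    type='CAFCNLoss',
    alpha=1.0,
    attn_with_dice_loss=False,
    with_attn=True,
    seg_with_loss_weight=True,
    ignore_index=255)
criterion = build_loss(loss_cfg)
# forward expects (out_neck, out_head, gt_kernels) as produced by the
# CAFCNNeck/head pair and returns dict(loss_seg=..., loss_attn=...).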
@ -0,0 +1,5 @@
from .cafcn_neck import CAFCNNeck
from .fpn_ocr import FPNOCR
from .fpn_seg import FPNSeg

__all__ = ['CAFCNNeck', 'FPNSeg', 'FPNOCR']
@ -0,0 +1,223 @@
import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.ops import DeformConv2dPack
from torch import nn

from mmdet.models.builder import NECKS


class CharAttn(nn.Module):
    """Implementation of the character attention module in `CA-FCN
    <https://arxiv.org/pdf/1809.06508.pdf>`_."""

    def __init__(self, in_channels=128, out_channels=128, deformable=False):
        super().__init__()
        assert isinstance(in_channels, int)
        assert isinstance(deformable, bool)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.deformable = deformable

        # attention layers
        self.attn_layer = nn.Sequential(
            ConvModule(
                in_channels,
                in_channels,
                3,
                stride=1,
                padding=1,
                norm_cfg=dict(type='BN')),
            ConvModule(
                in_channels,
                1,
                3,
                stride=1,
                padding=1,
                norm_cfg=dict(type='BN'),
                act_cfg=dict(type='Sigmoid')))

        conv_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=1,
            padding=1)
        if self.deformable:
            self.conv = DeformConv2dPack(**conv_kwargs)
        else:
            self.conv = nn.Conv2d(**conv_kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, in_feat):
        # calculate attention map
        attn_map = self.attn_layer(in_feat)  # N * 1 * H * W

        in_feat = self.relu(self.bn(self.conv(in_feat)))

        out_feat_map = self._upsample_mul(in_feat, 1 + attn_map)

        return out_feat_map, attn_map

    def _upsample_add(self, x, y):
        return F.interpolate(x, size=y.size()[2:]) + y

    def _upsample_mul(self, x, y):
        return F.interpolate(x, size=y.size()[2:]) * y


class FeatGenerator(nn.Module):
    """Generate attention-augmented stage feature from backbone stage
    feature."""

    def __init__(self,
                 in_channels=512,
                 out_channels=128,
                 deformable=True,
                 concat=False,
                 upsample=False,
                 with_attn=True):
        super().__init__()

        self.concat = concat
        self.upsample = upsample
        self.with_attn = with_attn

        if with_attn:
            self.char_attn = CharAttn(
                in_channels=in_channels,
                out_channels=out_channels,
                deformable=deformable)
        else:
            self.char_attn = ConvModule(
                in_channels,
                out_channels,
                3,
                stride=1,
                padding=1,
                norm_cfg=dict(type='BN'))

        if concat:
            self.conv_to_concat = ConvModule(
                out_channels,
                out_channels,
                3,
                stride=1,
                padding=1,
                norm_cfg=dict(type='BN'))

        kernel_size = (3, 1) if deformable else 3
        padding = (1, 0) if deformable else 1
        tmp_in_channels = out_channels * 2 if concat else out_channels

        self.conv_after_concat = ConvModule(
            tmp_in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=padding,
            norm_cfg=dict(type='BN'))

    def forward(self, x, y=None, size=None):
        if self.with_attn:
            feat_map, attn_map = self.char_attn(x)
        else:
            feat_map = self.char_attn(x)
            attn_map = feat_map

        if self.concat:
            y = self.conv_to_concat(y)
            feat_map = torch.cat((y, feat_map), dim=1)

        feat_map = self.conv_after_concat(feat_map)

        if self.upsample:
            feat_map = F.interpolate(feat_map, size)

        return attn_map, feat_map


@NECKS.register_module()
class CAFCNNeck(nn.Module):
    """Implementation of neck module in `CA-FCN
    <https://arxiv.org/pdf/1809.06508.pdf>`_.

    Args:
        in_channels (list[int]): Number of input channels for each scale.
        out_channels (int): Number of output channels for each scale.
        deformable (bool): If True, use deformable conv.
        with_attn (bool): If True, add attention for each output feature map.
    """

    def __init__(self,
                 in_channels=[128, 256, 512, 512],
                 out_channels=128,
                 deformable=True,
                 with_attn=True):
        super().__init__()

        self.deformable = deformable
        self.with_attn = with_attn

        # stage_in5_to_out5
        self.s5 = FeatGenerator(
            in_channels=in_channels[-1],
            out_channels=out_channels,
            deformable=deformable,
            concat=False,
            with_attn=with_attn)

        # stage_in4_to_out4
        self.s4 = FeatGenerator(
            in_channels=in_channels[-2],
            out_channels=out_channels,
            deformable=deformable,
            concat=True,
            with_attn=with_attn)

        # stage_in3_to_out3
        self.s3 = FeatGenerator(
            in_channels=in_channels[-3],
            out_channels=out_channels,
            deformable=False,
            concat=True,
            upsample=True,
            with_attn=with_attn)

        # stage_in2_to_out2
        self.s2 = FeatGenerator(
            in_channels=in_channels[-4],
            out_channels=out_channels,
            deformable=False,
            concat=True,
            upsample=True,
            with_attn=with_attn)

    def init_weights(self):
        pass

    def forward(self, feats):
        in_stage1 = feats[0]
        in_stage2, in_stage3 = feats[1], feats[2]
        in_stage4, in_stage5 = feats[3], feats[4]
        # out stage 5
        out_s5_attn_map, out_s5 = self.s5(in_stage5)

        # out stage 4
        out_s4_attn_map, out_s4 = self.s4(in_stage4, out_s5)

        # out stage 3
        out_s3_attn_map, out_s3 = self.s3(in_stage3, out_s4,
                                          in_stage2.size()[2:])

        # out stage 2
        out_s2_attn_map, out_s2 = self.s2(in_stage2, out_s3,
                                          in_stage1.size()[2:])

        return (out_s2_attn_map, out_s3_attn_map, out_s4_attn_map,
                out_s5_attn_map, out_s2)
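A quick, hedged smoke test of the neck's output contract. The dummy shapes below are assumptions chosen so that each stage-wise concatenation lines up (stages 3, 4 and 5 share a resolution here), and deformable conv is disabled to keep the sketch CPU-runnable:

# Hedged sketch: exercising CAFCNNeck with dummy pyramid features.
import torch

neck = CAFCNNeck(
    in_channels=[128, 256, 512, 512], deformable=False, with_attn=True)
feats = [
    torch.randn(2, 64, 32, 128),    # stage 1 (only its size is used)
    torch.randn(2, 128, 16, 64),    # stage 2
    torch.randn(2, 256, 8, 32),     # stage 3
    torch.randn(2, 512, 8, 32),     # stage 4
    torch.randn(2, 512, 8, 32),     # stage 5
]
outs = neck(feats)
# outs = (attn_s2, attn_s3, attn_s4, attn_s5, out_s2); the CAFCNLoss above
# consumes the attention maps, the head consumes out_s2 (2 x 128 x 32 x 128).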
@ -0,0 +1,70 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from mmdet.models.builder import NECKS


@NECKS.register_module()
class FPNOCR(nn.Module):
    """FPN-like network for segmentation based text recognition.

    Args:
        in_channels (list[int]): Number of input channels for each scale.
        out_channels (int): Number of output channels for each scale.
        last_stage_only (bool): If True, output last stage only.
    """

    def __init__(self, in_channels, out_channels, last_stage_only=True):
        super(FPNOCR, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)

        self.last_stage_only = last_stage_only

        self.lateral_convs = nn.ModuleList()
        self.smooth_convs_1x1 = nn.ModuleList()
        self.smooth_convs_3x3 = nn.ModuleList()

        for i in range(self.num_ins):
            l_conv = ConvModule(
                in_channels[i], out_channels, 1, norm_cfg=dict(type='BN'))
            self.lateral_convs.append(l_conv)

        for i in range(self.num_ins - 1):
            s_conv_1x1 = ConvModule(
                out_channels * 2, out_channels, 1, norm_cfg=dict(type='BN'))
            s_conv_3x3 = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                norm_cfg=dict(type='BN'))
            self.smooth_convs_1x1.append(s_conv_1x1)
            self.smooth_convs_3x3.append(s_conv_3x3)

    def init_weights(self):
        pass

    def _upsample_x2(self, x):
        return F.interpolate(x, scale_factor=2, mode='bilinear')

    def forward(self, inputs):
        lateral_features = [
            l_conv(inputs[i]) for i, l_conv in enumerate(self.lateral_convs)
        ]

        outs = []
        for i in range(len(self.smooth_convs_3x3), 0, -1):  # 3, 2, 1
            last_out = lateral_features[-1] if len(outs) == 0 else outs[-1]
            upsample = self._upsample_x2(last_out)
            upsample_cat = torch.cat((upsample, lateral_features[i - 1]),
                                     dim=1)
            smooth_1x1 = self.smooth_convs_1x1[i - 1](upsample_cat)
            smooth_3x3 = self.smooth_convs_3x3[i - 1](smooth_1x1)
            outs.append(smooth_3x3)

        return tuple(outs[-1:]) if self.last_stage_only else tuple(outs)
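A hedged sanity check of the top-down fusion above: each step upsamples by 2x and concatenates with the next lateral, so with strides doubling between inputs the final map recovers the first input's resolution. Channel counts and sizes are assumptions:

# Hedged sketch: FPNOCR doubles resolution at every fusion step.
import torch

fpn = FPNOCR(in_channels=[128, 256, 512, 512], out_channels=256)
inputs = [
    torch.randn(2, 128, 32, 128),
    torch.randn(2, 256, 16, 64),
    torch.randn(2, 512, 8, 32),
    torch.randn(2, 512, 4, 16),
]
(out, ) = fpn(inputs)          # last_stage_only=True -> single map
assert out.shape == (2, 256, 32, 128)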
@ -0,0 +1,43 @@
import torch.nn.functional as F
from mmcv.runner import auto_fp16

from mmdet.models.builder import NECKS
from mmdet.models.necks import FPN


@NECKS.register_module()
class FPNSeg(FPN):
    """Feature pyramid network for segmentation based text recognition.

    Args:
        in_channels (list[int]): Number of input channels for each scale.
        out_channels (int): Number of output channels for each scale.
        num_outs (int): Number of output scales.
        upsample_param (dict | None): Config dict for interpolate layer.
            Default: ``dict(scale_factor=1.0, mode='nearest')``
        last_stage_only (bool): If True, output last stage of FPN only.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 upsample_param=None,
                 last_stage_only=True):
        super().__init__(in_channels, out_channels, num_outs)
        self.upsample_param = upsample_param
        self.last_stage_only = last_stage_only

    @auto_fp16()
    def forward(self, inputs):
        outs = super().forward(inputs)

        outs = list(outs)

        if self.upsample_param is not None:
            outs[0] = F.interpolate(outs[0], **self.upsample_param)

        if self.last_stage_only:
            return tuple(outs[0:1])

        return tuple(outs[::-1])
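Note that ``upsample_param`` is unpacked straight into `F.interpolate`, so any of its keyword arguments are accepted. A hedged example config (all values illustrative):

# Hedged sketch: registry-style config for FPNSeg; keys mirror __init__.
neck_cfg = dict(
    type='FPNSeg',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=4,
    upsample_param=dict(scale_factor=2.0, mode='nearest'),
    last_stage_only=True)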
@ -0,0 +1,13 @@
from .base import BaseRecognizer
from .cafcn import CAFCNNet
from .crnn import CRNNNet
from .encode_decode_recognizer import EncodeDecodeRecognizer
from .robust_scanner import RobustScanner
from .sar import SARNet
from .seg_recognizer import SegRecognizer
from .transformer import TransformerNet

__all__ = [
    'BaseRecognizer', 'EncodeDecodeRecognizer', 'CRNNNet', 'SARNet',
    'TransformerNet', 'SegRecognizer', 'RobustScanner', 'CAFCNNet'
]
@ -0,0 +1,236 @@
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import mmcv
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.runner import auto_fp16
from mmcv.utils import print_log

from mmdet.utils import get_root_logger
from mmocr.core import imshow_text_label


class BaseRecognizer(nn.Module, metaclass=ABCMeta):
    """Base class for text recognition."""

    def __init__(self):
        super().__init__()
        self.fp16_enabled = False

    @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from images."""
        pass

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """
        Args:
            img (tensor): Tensors with shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details of the values of these keys, see
                :class:`mmdet.datasets.pipelines.Collect`.
            kwargs (keyword arguments): Specific to concrete implementation.
        """
        pass

    @abstractmethod
    def simple_test(self, img, img_metas, **kwargs):
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation.

        Args:
            imgs (list[tensor]): Tensor should have shape NxCxHxW,
                which contains all images in the batch.
            img_metas (list[list[dict]]): The metadata of images.
        """
        pass

    def init_weights(self, pretrained=None):
        """Initialize the weights for detector.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if pretrained is not None:
            logger = get_root_logger()
            print_log(f'load model from: {pretrained}', logger=logger)

    def forward_test(self, imgs, img_metas, **kwargs):
        """
        Args:
            imgs (tensor | list[tensor]): Tensor should have shape NxCxHxW,
                which contains all images in the batch.
            img_metas (list[dict] | list[list[dict]]):
                The outer list indicates images in a batch.
        """
        if isinstance(imgs, list):
            assert len(imgs) == len(img_metas)
            assert len(imgs) > 0
            assert imgs[0].size(0) == 1, 'aug test does not support ' \
                                         'inference with batch size ' \
                                         f'{imgs[0].size(0)}'
            return self.aug_test(imgs, img_metas, **kwargs)

        return self.simple_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        """Calls either :func:`forward_train` or :func:`forward_test`
        depending on whether ``return_loss`` is ``True``.

        Note that img and img_meta are single-nested (i.e. tensor and
        list[dict]).
        """
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)

        return self.forward_test(img, img_metas, **kwargs)

    def _parse_losses(self, losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw outputs of the network, which usually contain
                losses and other necessary information.

        Returns:
            tuple[tensor, dict]: (loss, log_vars), loss is the loss tensor \
                which may be a weighted sum of all losses, log_vars contains \
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(
                    f'{loss_name} is not a tensor or list of tensors')

        loss = sum(_value for _key, _value in log_vars.items()
                   if 'loss' in _key)

        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars

    def train_step(self, data, optimizer):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer update, which are done by an optimizer
        hook. Note that in some complicated cases or models (e.g. GAN),
        the whole process (including the back propagation and optimizer
        update) is also defined by this method.

        Args:
            data (dict): The outputs of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
                ``num_samples``.

                - ``loss`` is a tensor for back propagation, which is a \
                  weighted sum of multiple losses.
                - ``log_vars`` contains all the variables to be sent to the \
                  logger.
                - ``num_samples`` indicates the batch size used for \
                  averaging the logs (Note: for the DDP model, num_samples \
                  refers to the batch size for each GPU).
        """
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)

        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))

        return outputs

    def val_step(self, data, optimizer):
        """The iteration step during validation.

        This method shares the same signature as :func:`train_step`, but is
        used during val epochs. Note that the evaluation after training epochs
        is not implemented by this method, but by an evaluation hook.
        """
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)

        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))

        return outputs

    def show_result(self,
                    img,
                    result,
                    gt_label='',
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    **kwargs):
        """Draw `result` on `img`.

        Args:
            img (str or tensor): The image to be displayed.
            result (dict): The results to draw on `img`.
            gt_label (str): Ground truth label of img.
            win_name (str): The window name.
            wait_time (int): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image.
                Default: False.
            out_file (str or None): The output filename.
                Default: None.

        Returns:
            img (tensor): Only if not `show` or `out_file`.
        """
        img = mmcv.imread(img)
        img = img.copy()
        pred_label = None
        if 'text' in result.keys():
            pred_label = result['text']

        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False
        # draw text label
        if pred_label is not None:
            img = imshow_text_label(
                img,
                pred_label,
                gt_label,
                show=show,
                win_name=win_name,
                wait_time=wait_time,
                out_file=out_file)

        if not (show or out_file):
            warnings.warn('show==False and out_file is not specified, only '
                          'result image will be returned')
            return img

        return img
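For intuition, a hedged sketch of what `_parse_losses` produces for a typical loss dict (the values are invented, the logic mirrors the method above):

# Hedged sketch: _parse_losses sums every key containing 'loss' and
# converts the logged values to Python scalars.
import torch

losses = dict(
    loss_seg=torch.tensor(0.7),
    loss_attn=[torch.tensor(0.2), torch.tensor(0.1)])
# log_vars -> {'loss_seg': 0.7, 'loss_attn': 0.3, 'loss': 1.0}
# loss     -> tensor(1.0); the optimizer hook calls backward() on it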
@ -0,0 +1,7 @@
from mmdet.models.builder import DETECTORS
from .seg_recognizer import SegRecognizer


@DETECTORS.register_module()
class CAFCNNet(SegRecognizer):
    """Implementation of `CAFCN <https://arxiv.org/pdf/1809.06508.pdf>`_"""
@ -0,0 +1,18 @@
import torch
import torch.nn.functional as F

from mmdet.models.builder import DETECTORS
from .encode_decode_recognizer import EncodeDecodeRecognizer


@DETECTORS.register_module()
class CRNNNet(EncodeDecodeRecognizer):
    """CTC-loss based recognizer."""

    def forward_conversion(self, params, img):
        x = self.extract_feat(img)
        x = self.encoder(x)
        outs = self.decoder(x)
        outs = F.softmax(outs, dim=2)
        # identity op; presumably keeps `params` in the traced graph during
        # model conversion
        params = torch.pow(params, 1)
        return outs, params
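`forward_conversion` looks tailored for graph export. A hedged usage sketch; the export route below (rebinding `forward` before tracing) is an assumption, not shown anywhere in this diff, and the input size is illustrative:

# Hedged sketch: exporting CRNNNet through forward_conversion (assumed).
import torch

model.forward = model.forward_conversion        # route tracing through it
dummy_params = torch.ones(1)
dummy_img = torch.randn(1, 1, 32, 160)          # illustrative input size
torch.onnx.export(model, (dummy_params, dummy_img), 'crnn.onnx')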
@ -0,0 +1,173 @@
from mmdet.models.builder import DETECTORS, build_backbone, build_loss
from mmocr.models.builder import (build_convertor, build_decoder,
                                  build_encoder, build_preprocessor)
from .base import BaseRecognizer


@DETECTORS.register_module()
class EncodeDecodeRecognizer(BaseRecognizer):
    """Base class for encode-decode recognizer."""

    def __init__(self,
                 preprocessor=None,
                 backbone=None,
                 encoder=None,
                 decoder=None,
                 loss=None,
                 label_convertor=None,
                 train_cfg=None,
                 test_cfg=None,
                 max_seq_len=40,
                 pretrained=None):
        super().__init__()

        # Label convertor (str2tensor, tensor2str)
        assert label_convertor is not None
        label_convertor.update(max_seq_len=max_seq_len)
        self.label_convertor = build_convertor(label_convertor)

        # Preprocessor module, e.g., TPS
        self.preprocessor = None
        if preprocessor is not None:
            self.preprocessor = build_preprocessor(preprocessor)

        # Backbone
        assert backbone is not None
        self.backbone = build_backbone(backbone)

        # Encoder module
        self.encoder = None
        if encoder is not None:
            self.encoder = build_encoder(encoder)

        # Decoder module
        assert decoder is not None
        decoder.update(num_classes=self.label_convertor.num_classes())
        decoder.update(start_idx=self.label_convertor.start_idx)
        decoder.update(padding_idx=self.label_convertor.padding_idx)
        decoder.update(max_seq_len=max_seq_len)
        self.decoder = build_decoder(decoder)

        # Loss
        assert loss is not None
        loss.update(ignore_index=self.label_convertor.padding_idx)
        self.loss = build_loss(loss)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.max_seq_len = max_seq_len
        self.init_weights(pretrained=pretrained)

    def init_weights(self, pretrained=None):
        """Initialize the weights of recognizer."""
        super().init_weights(pretrained)

        if self.preprocessor is not None:
            self.preprocessor.init_weights()

        self.backbone.init_weights()

        if self.encoder is not None:
            self.encoder.init_weights()

        self.decoder.init_weights()

    def extract_feat(self, img):
        """Directly extract features from the backbone."""
        if self.preprocessor is not None:
            img = self.preprocessor(img)

        x = self.backbone(img)

        return x

    def forward_train(self, img, img_metas):
        """
        Args:
            img (tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dict where each dict
                contains: 'img_shape', 'filename', and may also contain
                'ori_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.

        Returns:
            dict[str, tensor]: A dictionary of loss components.
        """
        feat = self.extract_feat(img)

        gt_labels = [img_meta['text'] for img_meta in img_metas]

        targets_dict = self.label_convertor.str2tensor(gt_labels)

        out_enc = None
        if self.encoder is not None:
            out_enc = self.encoder(feat, img_metas)

        out_dec = self.decoder(
            feat, out_enc, targets_dict, img_metas, train_mode=True)

        loss_inputs = (
            out_dec,
            targets_dict,
        )
        losses = self.loss(*loss_inputs)

        return losses

    def simple_test(self, img, img_metas, **kwargs):
        """Test function without test time augmentation.

        Args:
            img (torch.Tensor): Image input tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            list[str]: Text label result of each image.
        """
        feat = self.extract_feat(img)

        out_enc = None
        if self.encoder is not None:
            out_enc = self.encoder(feat, img_metas)

        out_dec = self.decoder(
            feat, out_enc, None, img_metas, train_mode=False)

        label_indexes, label_scores = \
            self.label_convertor.tensor2idx(out_dec, img_metas)
        label_strings = self.label_convertor.idx2str(label_indexes)

        # flatten batch results
        results = []
        for string, score in zip(label_strings, label_scores):
            results.append(dict(text=string, score=score))

        return results

    def merge_aug_results(self, aug_results):
        out_text, out_score = '', -1
        for result in aug_results:
            text = result[0]['text']
            score = sum(result[0]['score']) / max(1, len(text))
            if score > out_score:
                out_text = text
                out_score = score
        out_results = [dict(text=out_text, score=out_score)]
        return out_results

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation.

        Args:
            imgs (list[tensor]): Tensor should have shape NxCxHxW,
                which contains all images in the batch.
            img_metas (list[list[dict]]): The metadata of images.
        """
        aug_results = []
        for img, img_meta in zip(imgs, img_metas):
            result = self.simple_test(img, img_meta, **kwargs)
            aug_results.append(result)

        return self.merge_aug_results(aug_results)
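A hedged end-to-end config sketch for an encode-decode recognizer. The constructor above wires convertor metadata (num_classes, start/padding indices) into the decoder and loss, so a config states ``max_seq_len`` only once; the component types named below are assumptions chosen for illustration, not values taken from this diff:

# Hedged sketch: registry-style recognizer config.
model = dict(
    type='EncodeDecodeRecognizer',
    backbone=dict(type='VeryDeepVgg', leaky_relu=False),       # assumed
    encoder=None,
    decoder=dict(type='CRNNDecoder', in_channels=512),         # assumed
    loss=dict(type='CTCLoss'),                                 # assumed
    label_convertor=dict(
        type='CTCConvertor', dict_type='DICT36', with_unknown=False,
        lower=True),
    max_seq_len=40)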
@ -0,0 +1,7 @@
from mmdet.models.builder import DETECTORS
from .encode_decode_recognizer import EncodeDecodeRecognizer


@DETECTORS.register_module()
class SARNet(EncodeDecodeRecognizer):
    """Implementation of `SAR <https://arxiv.org/abs/1811.00751>`_"""
@ -0,0 +1,153 @@
from mmdet.models.builder import (DETECTORS, build_backbone, build_head,
                                  build_loss, build_neck)
from mmocr.models.builder import build_convertor, build_preprocessor
from .base import BaseRecognizer


@DETECTORS.register_module()
class SegRecognizer(BaseRecognizer):
    """Base class for segmentation based recognizer."""

    def __init__(self,
                 preprocessor=None,
                 backbone=None,
                 neck=None,
                 head=None,
                 loss=None,
                 label_convertor=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__()

        # Label convertor
        assert label_convertor is not None
        self.label_convertor = build_convertor(label_convertor)

        # Preprocessor module, e.g., TPS
        self.preprocessor = None
        if preprocessor is not None:
            self.preprocessor = build_preprocessor(preprocessor)

        # Backbone
        assert backbone is not None
        self.backbone = build_backbone(backbone)

        # Neck
        assert neck is not None
        self.neck = build_neck(neck)

        # Head
        assert head is not None
        head.update(num_classes=self.label_convertor.num_classes())
        self.head = build_head(head)

        # Loss
        assert loss is not None
        self.loss = build_loss(loss)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)

    def init_weights(self, pretrained=None):
        """Initialize the weights of recognizer."""
        super().init_weights(pretrained)

        if self.preprocessor is not None:
            self.preprocessor.init_weights()

        self.backbone.init_weights(pretrained=pretrained)

        if self.neck is not None:
            self.neck.init_weights()

        self.head.init_weights()

    def extract_feat(self, img):
        """Directly extract features from the backbone."""
        if self.preprocessor is not None:
            img = self.preprocessor(img)

        x = self.backbone(img)

        return x

    def forward_train(self, img, img_metas, gt_kernels=None):
        """
        Args:
            img (tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dict where each dict
                contains: 'img_shape', 'filename', and may also contain
                'ori_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.

        Returns:
            dict[str, tensor]: A dictionary of loss components.
        """

        feats = self.extract_feat(img)

        out_neck = self.neck(feats)

        out_head = self.head(out_neck)

        loss_inputs = (out_neck, out_head, gt_kernels)

        losses = self.loss(*loss_inputs)

        return losses

    def simple_test(self, img, img_metas, **kwargs):
        """Test function without test time augmentation.

        Args:
            img (torch.Tensor): Image input tensor.
            img_metas (list[dict]): List of image information.

        Returns:
            list[str]: Text label result of each image.
        """

        feat = self.extract_feat(img)

        out_neck = self.neck(feat)

        out_head = self.head(out_neck)

        texts, scores = self.label_convertor.tensor2str(out_head, img_metas)

        # flatten batch results
        results = []
        for text, score in zip(texts, scores):
            results.append(dict(text=text, score=score))

        return results

    def merge_aug_results(self, aug_results):
        out_text, out_score = '', -1
        for result in aug_results:
            text = result[0]['text']
            score = sum(result[0]['score']) / max(1, len(text))
            if score > out_score:
                out_text = text
                out_score = score
        out_results = [dict(text=out_text, score=out_score)]
        return out_results

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test function with test time augmentation.

        Args:
            imgs (list[tensor]): Tensor should have shape NxCxHxW,
                which contains all images in the batch.
            img_metas (list[list[dict]]): The metadata of images.
        """
        aug_results = []
        for img, img_meta in zip(imgs, img_metas):
            result = self.simple_test(img, img_meta, **kwargs)
            aug_results.append(result)

        return self.merge_aug_results(aug_results)
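A hedged minimal config for a segmentation based recognizer; the backbone and head types below are assumptions chosen to pair with the neck and loss added elsewhere in this commit:

# Hedged sketch: SegRecognizer config; component types are placeholders.
model = dict(
    type='SegRecognizer',
    backbone=dict(type='ResNet31OCR'),                         # assumed
    neck=dict(type='CAFCNNeck', in_channels=[128, 256, 512, 512]),
    head=dict(type='SegHead', in_channels=128),                # assumed
    loss=dict(type='CAFCNLoss'),
    label_convertor=dict(
        type='SegConvertor', dict_type='DICT36', with_unknown=True,
        lower=True))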
@ -0,0 +1,7 @@
from mmdet.models.builder import DETECTORS
from .encode_decode_recognizer import EncodeDecodeRecognizer


@DETECTORS.register_module()
class TransformerNet(EncodeDecodeRecognizer):
    """Implementation of Transformer based OCR."""
@ -0,0 +1,74 @@
import os.path as osp
import tempfile

import numpy as np
import pytest

from mmocr.datasets.base_dataset import BaseDataset


def _create_dummy_ann_file(ann_file):
    ann_info1 = 'sample1.jpg hello'
    ann_info2 = 'sample2.jpg world'

    with open(ann_file, 'w') as fw:
        for ann_info in [ann_info1, ann_info2]:
            fw.write(ann_info + '\n')


def _create_dummy_loader():
    loader = dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(type='LineStrParser', keys=['file_name', 'text']))
    return loader


def test_custom_dataset():
    tmp_dir = tempfile.TemporaryDirectory()
    # create dummy data
    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
    _create_dummy_ann_file(ann_file)
    loader = _create_dummy_loader()

    for mode in [True, False]:
        dataset = BaseDataset(ann_file, loader, pipeline=[], test_mode=mode)

        # test len
        assert len(dataset) == len(dataset.data_infos)

        # test set group flag
        assert np.allclose(dataset.flag, [0, 0])

        # test prepare_train_img
        expect_results = {
            'img_info': {
                'file_name': 'sample1.jpg',
                'text': 'hello'
            },
            'img_prefix': ''
        }
        assert dataset.prepare_train_img(0) == expect_results

        # test prepare_test_img
        assert dataset.prepare_test_img(0) == expect_results

        # test __getitem__
        assert dataset[0] == expect_results

        # test get_next_index
        assert dataset._get_next_index(0) == 1

        # test format_results
        expect_results_copy = {
            key: value
            for key, value in expect_results.items()
        }
        dataset.format_results(expect_results)
        assert expect_results_copy == expect_results

        # test evaluate
        with pytest.raises(NotImplementedError):
            dataset.evaluate(expect_results)

    tmp_dir.cleanup()
@ -0,0 +1,96 @@
import math

import numpy as np
import pytest

from mmocr.datasets.pipelines.crop import (box_jitter, convert_canonical,
                                           crop_img, sort_vertex, warp_img)


def test_order_vertex():
    dummy_points_x = [20, 20, 120, 120]
    dummy_points_y = [20, 40, 40, 20]

    with pytest.raises(AssertionError):
        sort_vertex([], dummy_points_y)
    with pytest.raises(AssertionError):
        sort_vertex(dummy_points_x, [])

    ordered_points_x, ordered_points_y = sort_vertex(dummy_points_x,
                                                     dummy_points_y)

    expect_points_x = [20, 120, 120, 20]
    expect_points_y = [20, 20, 40, 40]

    assert np.allclose(ordered_points_x, expect_points_x)
    assert np.allclose(ordered_points_y, expect_points_y)


def test_convert_canonical():
    dummy_points_x = [120, 120, 20, 20]
    dummy_points_y = [20, 40, 40, 20]

    with pytest.raises(AssertionError):
        convert_canonical([], dummy_points_y)
    with pytest.raises(AssertionError):
        convert_canonical(dummy_points_x, [])

    ordered_points_x, ordered_points_y = convert_canonical(
        dummy_points_x, dummy_points_y)

    expect_points_x = [20, 120, 120, 20]
    expect_points_y = [20, 20, 40, 40]

    assert np.allclose(ordered_points_x, expect_points_x)
    assert np.allclose(ordered_points_y, expect_points_y)


def test_box_jitter():
    dummy_points_x = [20, 120, 120, 20]
    dummy_points_y = [20, 20, 40, 40]

    kwargs = dict(jitter_ratio_x=0.0, jitter_ratio_y=0.0)

    with pytest.raises(AssertionError):
        box_jitter([], dummy_points_y)
    with pytest.raises(AssertionError):
        box_jitter(dummy_points_x, [])
    with pytest.raises(AssertionError):
        box_jitter(dummy_points_x, dummy_points_y, jitter_ratio_x=1.)
    with pytest.raises(AssertionError):
        box_jitter(dummy_points_x, dummy_points_y, jitter_ratio_y=1.)

    box_jitter(dummy_points_x, dummy_points_y, **kwargs)

    assert np.allclose(dummy_points_x, [20, 120, 120, 20])
    assert np.allclose(dummy_points_y, [20, 20, 40, 40])


def test_opencv_crop():
    dummy_img = np.ones((600, 600, 3), dtype=np.uint8)
    dummy_box = [20, 20, 120, 20, 120, 40, 20, 40]

    cropped_img = warp_img(dummy_img, dummy_box)

    with pytest.raises(AssertionError):
        warp_img(dummy_img, [])
    with pytest.raises(AssertionError):
        warp_img(dummy_img, [20, 40, 40, 20])

    assert math.isclose(cropped_img.shape[0], 20)
    assert math.isclose(cropped_img.shape[1], 100)


def test_min_rect_crop():
    dummy_img = np.ones((600, 600, 3), dtype=np.uint8)
    dummy_box = [20, 20, 120, 20, 120, 40, 20, 40]

    cropped_img = crop_img(dummy_img, dummy_box)

    with pytest.raises(AssertionError):
        crop_img(dummy_img, [])
    with pytest.raises(AssertionError):
        crop_img(dummy_img, [20, 40, 40, 20])

    assert math.isclose(cropped_img.shape[0], 20)
    assert math.isclose(cropped_img.shape[1], 100)
@ -0,0 +1,83 @@
import json
import os.path as osp
import tempfile

import numpy as np

from mmocr.datasets.text_det_dataset import TextDetDataset


def _create_dummy_ann_file(ann_file):
    ann_info1 = {
        'file_name': 'sample1.jpg',
        'height': 640,
        'width': 640,
        'annotations': [{
            'iscrowd': 0,
            'category_id': 1,
            'bbox': [50, 70, 80, 100],
            'segmentation': [[50, 70, 80, 70, 80, 100, 50, 100]]
        }, {
            'iscrowd': 1,
            'category_id': 1,
            'bbox': [120, 140, 200, 200],
            'segmentation': [[120, 140, 200, 140, 200, 200, 120, 200]]
        }]
    }

    with open(ann_file, 'w') as fw:
        fw.write(json.dumps(ann_info1) + '\n')


def _create_dummy_loader():
    loader = dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineJsonParser',
            keys=['file_name', 'height', 'width', 'annotations']))
    return loader


def test_detect_dataset():
    tmp_dir = tempfile.TemporaryDirectory()
    # create dummy data
    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
    _create_dummy_ann_file(ann_file)

    # test initialization
    loader = _create_dummy_loader()
    dataset = TextDetDataset(ann_file, loader, pipeline=[])

    # test _parse_anno_info
    img_ann_info = dataset.data_infos[0]
    ann = dataset._parse_anno_info(img_ann_info['annotations'])
    assert np.allclose(ann['bboxes'], [[50., 70., 80., 100.]])
    assert np.allclose(ann['labels'], [1])
    assert np.allclose(ann['bboxes_ignore'], [[120, 140, 200, 200]])
    assert np.allclose(ann['masks'], [[[50, 70, 80, 70, 80, 100, 50, 100]]])
    assert np.allclose(ann['masks_ignore'],
                       [[[120, 140, 200, 140, 200, 200, 120, 200]]])

    tmp_dir.cleanup()

    # test prepare_train_img
    pipeline_results = dataset.prepare_train_img(0)
    assert np.allclose(pipeline_results['bbox_fields'], [])
    assert np.allclose(pipeline_results['mask_fields'], [])
    assert np.allclose(pipeline_results['seg_fields'], [])
    expect_img_info = {'filename': 'sample1.jpg', 'height': 640, 'width': 640}
    assert pipeline_results['img_info'] == expect_img_info

    # test evaluation
    metrics = 'hmean-iou'
    results = [{'boundary_result': [[50, 70, 80, 70, 80, 100, 50, 100, 1]]}]
    eval_res = dataset.evaluate(results, metrics)

    assert eval_res['hmean-iou:hmean'] == 1
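A short note on the final assertion; the arithmetic is standard detection scoring, stated here as a hedged reading of the test:

# hmean-iou is the harmonic mean of precision and recall over IoU-matched
# boundaries. Here the single prediction exactly matches the one
# non-ignored ground truth, so P = R = 1 and
# hmean = 2 * P * R / (P + R) = 1.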
@ -0,0 +1,71 @@
import json
import os.path as osp
import tempfile

import pytest
from tools.data.utils.txt2lmdb import converter

from mmocr.datasets.utils.loader import HardDiskLoader, LmdbLoader, Loader


def _create_dummy_line_str_file(ann_file):
    ann_info1 = 'sample1.jpg hello'
    ann_info2 = 'sample2.jpg world'

    with open(ann_file, 'w') as fw:
        for ann_info in [ann_info1, ann_info2]:
            fw.write(ann_info + '\n')


def _create_dummy_line_json_file(ann_file):
    ann_info1 = {'filename': 'sample1.jpg', 'text': 'hello'}
    ann_info2 = {'filename': 'sample2.jpg', 'text': 'world'}

    with open(ann_file, 'w') as fw:
        for ann_info in [ann_info1, ann_info2]:
            fw.write(json.dumps(ann_info) + '\n')


def test_loader():
    tmp_dir = tempfile.TemporaryDirectory()
    # create dummy data
    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
    _create_dummy_line_str_file(ann_file)

    parser = dict(
        type='LineStrParser',
        keys=['filename', 'text'],
        keys_idx=[0, 1],
        separator=' ')

    with pytest.raises(AssertionError):
        Loader(ann_file, parser, repeat=0)
    with pytest.raises(AssertionError):
        Loader(ann_file, [], repeat=1)
    with pytest.raises(AssertionError):
        Loader('sample.txt', parser, repeat=1)
    with pytest.raises(NotImplementedError):
        loader = Loader(ann_file, parser, repeat=1)
        print(loader)

    # test text loader and line str parser
    text_loader = HardDiskLoader(ann_file, parser, repeat=1)
    assert len(text_loader) == 2
    assert text_loader.ori_data_infos[0] == 'sample1.jpg hello'
    assert text_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'}

    # test text loader and line json parser
    _create_dummy_line_json_file(ann_file)
    json_parser = dict(type='LineJsonParser', keys=['filename', 'text'])
    text_loader = HardDiskLoader(ann_file, json_parser, repeat=1)
    assert text_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'}

    # test lmdb loader and line str parser
    _create_dummy_line_str_file(ann_file)
    lmdb_file = osp.join(tmp_dir.name, 'fake_data.lmdb')
    converter(ann_file, lmdb_file)

    lmdb_loader = LmdbLoader(lmdb_file, parser, repeat=1)
    assert lmdb_loader[0] == {'filename': 'sample1.jpg', 'text': 'hello'}

    tmp_dir.cleanup()
@ -0,0 +1,51 @@
import math
import os.path as osp
import tempfile

from mmocr.datasets.ocr_dataset import OCRDataset


def _create_dummy_ann_file(ann_file):
    ann_info1 = 'sample1.jpg hello'
    ann_info2 = 'sample2.jpg world'

    with open(ann_file, 'w') as fw:
        for ann_info in [ann_info1, ann_info2]:
            fw.write(ann_info + '\n')


def _create_dummy_loader():
    loader = dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(type='LineStrParser', keys=['file_name', 'text']))
    return loader


def test_ocr_dataset():
    tmp_dir = tempfile.TemporaryDirectory()
    # create dummy data
    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
    _create_dummy_ann_file(ann_file)

    # test initialization
    loader = _create_dummy_loader()
    dataset = OCRDataset(ann_file, loader, pipeline=[])

    tmp_dir.cleanup()

    # test pre_pipeline
    img_info = dataset.data_infos[0]
    results = dict(img_info=img_info)
    dataset.pre_pipeline(results)
    assert results['img_prefix'] == dataset.img_prefix
    assert results['text'] == img_info['text']

    # test evaluation
    metric = 'acc'
    results = [{'text': 'hello'}, {'text': 'worl'}]
    eval_res = dataset.evaluate(results, metric)

    assert math.isclose(eval_res['word_acc'], 0.5, abs_tol=1e-4)
    assert math.isclose(eval_res['char_precision'], 1.0, abs_tol=1e-4)
    assert math.isclose(eval_res['char_recall'], 0.9, abs_tol=1e-4)
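The expected metric values follow directly from the dummy data; a hedged worked example of the arithmetic the test encodes:

# Ground truths: 'hello' and 'world' (10 characters total).
# Predictions: 'hello' (5 matched) and 'worl' (4 matched, 0 spurious).
# word_acc       = 1 exact match / 2 samples        = 0.5
# char_precision = 9 matched / 9 predicted chars    = 1.0
# char_recall    = 9 matched / 10 ground-truth chars = 0.9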
@ -0,0 +1,127 @@
import json
import math
import os.path as osp
import tempfile

import pytest

from mmocr.datasets.ocr_seg_dataset import OCRSegDataset


def _create_dummy_ann_file(ann_file):
    ann_info1 = {
        'file_name': 'sample1.png',
        'annotations': [{
            'char_text': 'F',
            'char_box': [11.0, 0.0, 22.0, 0.0, 12.0, 12.0, 0.0, 12.0]
        }, {
            'char_text': 'r',
            'char_box': [23.0, 2.0, 31.0, 1.0, 24.0, 11.0, 16.0, 11.0]
        }, {
            'char_text': 'o',
            'char_box': [33.0, 2.0, 43.0, 2.0, 36.0, 12.0, 25.0, 12.0]
        }, {
            'char_text': 'm',
            'char_box': [46.0, 2.0, 61.0, 2.0, 53.0, 12.0, 39.0, 12.0]
        }, {
            'char_text': ':',
            'char_box': [61.0, 2.0, 69.0, 2.0, 63.0, 12.0, 55.0, 12.0]
        }],
        'text': 'From:'
    }
    ann_info2 = {
        'file_name': 'sample2.png',
        'annotations': [{
            'char_text': 'o',
            'char_box': [0.0, 5.0, 7.0, 5.0, 9.0, 15.0, 2.0, 15.0]
        }, {
            'char_text': 'u',
            'char_box': [7.0, 4.0, 14.0, 4.0, 18.0, 18.0, 11.0, 18.0]
        }, {
            'char_text': 't',
            'char_box': [13.0, 1.0, 19.0, 2.0, 24.0, 18.0, 17.0, 18.0]
        }],
        'text': 'out'
    }

    with open(ann_file, 'w') as fw:
        for ann_info in [ann_info1, ann_info2]:
            fw.write(json.dumps(ann_info) + '\n')

    return ann_info1, ann_info2


def _create_dummy_loader():
    loader = dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineJsonParser', keys=['file_name', 'text', 'annotations']))
    return loader


def test_ocr_seg_dataset():
    tmp_dir = tempfile.TemporaryDirectory()
    # create dummy data
    ann_file = osp.join(tmp_dir.name, 'fake_data.txt')
    ann_info1, ann_info2 = _create_dummy_ann_file(ann_file)

    # test initialization
    loader = _create_dummy_loader()
    dataset = OCRSegDataset(ann_file, loader, pipeline=[])

    tmp_dir.cleanup()

    # test pre_pipeline
    img_info = dataset.data_infos[0]
    results = dict(img_info=img_info)
    dataset.pre_pipeline(results)
    assert results['img_prefix'] == dataset.img_prefix

    # test _parse_anno_info
    annos = ann_info1['annotations']
    with pytest.raises(AssertionError):
        dataset._parse_anno_info(annos[0])
    annos2 = ann_info2['annotations']
    with pytest.raises(AssertionError):
        dataset._parse_anno_info([{'char_text': 'i'}])
    with pytest.raises(AssertionError):
        dataset._parse_anno_info([{'char_box': [1, 2, 3, 4, 5, 6, 7, 8]}])
    annos2[0]['char_box'] = [1, 2, 3]
    with pytest.raises(AssertionError):
        dataset._parse_anno_info(annos2)

    return_anno = dataset._parse_anno_info(annos)
    assert return_anno['chars'] == ['F', 'r', 'o', 'm', ':']
    assert len(return_anno['char_rects']) == 5

    # test prepare_train_img
    expect_results = {
        'img_info': {
            'filename': 'sample1.png'
        },
        'img_prefix': '',
        'ann_info': return_anno
    }
    data = dataset.prepare_train_img(0)
    assert data == expect_results

    # test evaluation
    metric = 'acc'
    results = [{'text': 'From:'}, {'text': 'ou'}]
    eval_res = dataset.evaluate(results, metric)

    assert math.isclose(eval_res['word_acc'], 0.5, abs_tol=1e-4)
    assert math.isclose(eval_res['char_precision'], 1.0, abs_tol=1e-4)
    assert math.isclose(eval_res['char_recall'], 0.857, abs_tol=1e-4)
@ -0,0 +1,93 @@
import os.path as osp
import tempfile

import numpy as np
import pytest

from mmocr.datasets.pipelines.ocr_seg_targets import OCRSegTargets


def _create_dummy_dict_file(dict_file):
    chars = list('0123456789')
    with open(dict_file, 'w') as fw:
        for char in chars:
            fw.write(char + '\n')


def test_ocr_segm_targets():
    tmp_dir = tempfile.TemporaryDirectory()
    # create dummy dict file
    dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
    _create_dummy_dict_file(dict_file)
    # dummy label convertor
    label_convertor = dict(
        type='SegConvertor',
        dict_file=dict_file,
        with_unknown=True,
        lower=True)
    # test init
    with pytest.raises(AssertionError):
        OCRSegTargets(None, 0.5, 0.5)
    with pytest.raises(AssertionError):
        OCRSegTargets(label_convertor, '1by2', 0.5)
    with pytest.raises(AssertionError):
        OCRSegTargets(label_convertor, 0.5, 2)

    ocr_seg_tgt = OCRSegTargets(label_convertor, 0.5, 0.5)
    # test generate kernels
    img_size = (8, 8)
    pad_size = (8, 10)
    char_boxes = [[2, 2, 6, 6]]
    char_idxs = [2]

    with pytest.raises(AssertionError):
        ocr_seg_tgt.generate_kernels(8, pad_size, char_boxes, char_idxs, 0.5,
                                     True)
    with pytest.raises(AssertionError):
        ocr_seg_tgt.generate_kernels(img_size, pad_size, [2, 2, 6, 6],
                                     char_idxs, 0.5, True)
    with pytest.raises(AssertionError):
        ocr_seg_tgt.generate_kernels(img_size, pad_size, char_boxes, 2, 0.5,
                                     True)

    attn_tgt = ocr_seg_tgt.generate_kernels(
        img_size, pad_size, char_boxes, char_idxs, 0.5, binary=True)
    expect_attn_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                       [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                       [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                       [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
    assert np.allclose(attn_tgt, np.array(expect_attn_tgt, dtype=np.int32))

    segm_tgt = ocr_seg_tgt.generate_kernels(
        img_size, pad_size, char_boxes, char_idxs, 0.5, binary=False)
    expect_segm_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                       [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                       [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                       [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                       [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
    assert np.allclose(segm_tgt, np.array(expect_segm_tgt, dtype=np.int32))

    # test __call__
    results = {}
    results['img_shape'] = (4, 4, 3)
    results['resize_shape'] = (8, 8, 3)
    results['pad_shape'] = (8, 10)
    results['ann_info'] = {}
    results['ann_info']['char_rects'] = [[1, 1, 3, 3]]
    results['ann_info']['chars'] = ['1']

    results = ocr_seg_tgt(results)
    assert results['mask_fields'] == ['gt_kernels']
    assert np.allclose(results['gt_kernels'].masks[0],
                       np.array(expect_attn_tgt, dtype=np.int32))
    assert np.allclose(results['gt_kernels'].masks[1],
                       np.array(expect_segm_tgt, dtype=np.int32))

    tmp_dir.cleanup()
@ -0,0 +1,94 @@
import math
import unittest.mock as mock

import numpy as np
import torch
import torchvision.transforms.functional as TF

import mmocr.datasets.pipelines.ocr_transforms as transforms


def test_resize_ocr():
    input_img = np.ones((64, 256, 3), dtype=np.uint8)

    rci = transforms.ResizeOCR(
        32, min_width=32, max_width=160, keep_aspect_ratio=True)
    results = {'img_shape': input_img.shape, 'img': input_img}

    # test call
    results = rci(results)
    assert np.allclose([32, 160, 3], results['pad_shape'])
    assert np.allclose([32, 160, 3], results['img'].shape)
    assert 'valid_ratio' in results
    assert math.isclose(results['valid_ratio'], 0.8)
    assert math.isclose(np.sum(results['img'][:, 129:, :]), 0)

    rci = transforms.ResizeOCR(
        32, min_width=32, max_width=160, keep_aspect_ratio=False)
    results = {'img_shape': input_img.shape, 'img': input_img}
    results = rci(results)
    assert math.isclose(results['valid_ratio'], 1)


def test_to_tensor():
    input_img = np.ones((64, 256, 3), dtype=np.uint8)

    expect_output = TF.to_tensor(input_img)
    rci = transforms.ToTensorOCR()

    results = {'img': input_img}
    results = rci(results)

    assert np.allclose(results['img'].numpy(), expect_output.numpy())


def test_normalize():
    inputs = torch.zeros(3, 10, 10)

    expect_output = torch.ones_like(inputs) * (-1)
    rci = transforms.NormalizeOCR(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    results = {'img': inputs}
    results = rci(results)

    assert np.allclose(results['img'].numpy(), expect_output.numpy())


@mock.patch('%s.transforms.np.random.random' % __name__)
def test_online_crop(mock_random):
    kwargs = dict(
        box_keys=['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4'],
        jitter_prob=0.5,
        max_jitter_ratio_x=0.05,
        max_jitter_ratio_y=0.02)

    mock_random.side_effect = [0.1, 1, 1, 1]

    src_img = np.ones((100, 100, 3), dtype=np.uint8)
    results = {
        'img': src_img,
        'img_info': {
            'x1': '20',
            'y1': '20',
            'x2': '40',
            'y2': '20',
            'x3': '40',
            'y3': '40',
            'x4': '20',
            'y4': '40'
        }
    }

    rci = transforms.OnlineCropOCR(**kwargs)

    results = rci(results)

    assert np.allclose(results['img_shape'], [20, 20, 3])

    # test not crop
    mock_random.side_effect = [0.1, 1, 1, 1]
    results['img_info'] = {}
    results['img'] = src_img

    results = rci(results)
    assert np.allclose(results['img'].shape, [100, 100, 3])
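The 0.8 expected above is easy to reconstruct; a hedged walk-through of the ResizeOCR arithmetic the test relies on:

# A 64x256 image resized to height 32 with aspect ratio kept becomes
# 32x128; padding to max_width=160 then gives
# valid_ratio = 128 / 160 = 0.8,
# which is why columns beyond index 128 are asserted to be zero.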
@ -0,0 +1,59 @@
import json

import pytest

from mmocr.datasets.utils.parser import LineJsonParser, LineStrParser


def test_line_str_parser():
    data_ret = ['sample1.jpg hello', 'sample2.jpg world']
    keys = ['filename', 'text']
    keys_idx = [0, 1]
    separator = ' '

    # test init
    with pytest.raises(AssertionError):
        parser = LineStrParser('filename', keys_idx, separator)
    with pytest.raises(AssertionError):
        parser = LineStrParser(keys, keys_idx, [' '])
    with pytest.raises(AssertionError):
        parser = LineStrParser(keys, [0], separator)

    # test get_item
    parser = LineStrParser(keys, keys_idx, separator)
    assert parser.get_item(data_ret, 0) == \
        {'filename': 'sample1.jpg', 'text': 'hello'}

    with pytest.raises(Exception):
        parser = LineStrParser(['filename', 'text', 'ignore'], [0, 1, 2],
                               separator)
        parser.get_item(data_ret, 0)


def test_line_dict_parser():
    data_ret = [
        json.dumps({
            'filename': 'sample1.jpg',
            'text': 'hello'
        }),
        json.dumps({
            'filename': 'sample2.jpg',
            'text': 'world'
        })
    ]
    keys = ['filename', 'text']

    # test init
    with pytest.raises(AssertionError):
        parser = LineJsonParser('filename')
    with pytest.raises(AssertionError):
        parser = LineJsonParser([])

    # test get_item
    parser = LineJsonParser(keys)
    assert parser.get_item(data_ret, 0) == \
        {'filename': 'sample1.jpg', 'text': 'hello'}

    with pytest.raises(Exception):
        parser = LineJsonParser(['img_name', 'text'])
        parser.get_item(data_ret, 0)

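For context, a minimal sketch of the behaviour these assertions pin down, assuming `get_item` simply splits the selected line on the separator and pairs each key with its indexed field (the real parser also validates its arguments; the helper name below is illustrative only):

def line_str_get_item(data_ret, index, keys, keys_idx, separator):
    # hypothetical re-implementation for illustration only
    parts = data_ret[index].strip().split(separator)
    return {key: parts[idx] for key, idx in zip(keys, keys_idx)}

assert line_str_get_item(['sample1.jpg hello'], 0, ['filename', 'text'],
                         [0, 1], ' ') == {'filename': 'sample1.jpg',
                                          'text': 'hello'}
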
@ -0,0 +1,33 @@
import numpy as np
import pytest

from mmocr.datasets.pipelines.test_time_aug import MultiRotateAugOCR


def test_multi_rotate_aug_ocr():
    input_img1 = np.ones((64, 256, 3), dtype=np.uint8)
    input_img2 = np.ones((64, 32, 3), dtype=np.uint8)

    rci = MultiRotateAugOCR(transforms=[], rotate_degrees=[0, 90, 270])

    # test invalid arguments
    with pytest.raises(AssertionError):
        MultiRotateAugOCR(transforms=[], rotate_degrees=[45])
    with pytest.raises(AssertionError):
        MultiRotateAugOCR(transforms=[], rotate_degrees=[20.5])

    # test call with input_img1
    results = {'img_shape': input_img1.shape, 'img': input_img1}
    results = rci(results)
    assert np.allclose([64, 256, 3], results['img_shape'])
    assert len(results['img']) == 1
    assert len(results['img_shape']) == 1
    assert np.allclose([64, 256, 3], results['img_shape'][0])

    # test call with input_img2
    results = {'img_shape': input_img2.shape, 'img': input_img2}
    results = rci(results)
    assert np.allclose([64, 32, 3], results['img_shape'])
    assert len(results['img']) == 3
    assert len(results['img_shape']) == 3
    assert np.allclose([64, 32, 3], results['img_shape'][0])

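A sketch of the decision rule these two cases suggest, assuming extra rotated views are produced only for roughly vertical crops (height greater than width, i.e. likely vertical text); this is inferred from the assertions above, not quoted from the implementation:

def num_test_time_views(img_shape, rotate_degrees=(0, 90, 270)):
    # landscape crops keep a single view; portrait crops get one
    # view per configured rotation angle (assumed rule)
    h, w = img_shape[:2]
    return len(rotate_degrees) if h > w else 1

assert num_test_time_views((64, 256, 3)) == 1
assert num_test_time_views((64, 32, 3)) == 3
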
@ -0,0 +1,77 @@
import os.path as osp
import tempfile

import numpy as np
import pytest
import torch

from mmocr.models.textrecog.convertors import AttnConvertor


def _create_dummy_dict_file(dict_file):
    characters = list('helowrd')
    with open(dict_file, 'w') as fw:
        for char in characters:
            fw.write(char + '\n')


def test_attn_label_convertor():
    tmp_dir = tempfile.TemporaryDirectory()
    # create dummy data
    dict_file = osp.join(tmp_dir.name, 'fake_dict.txt')
    _create_dummy_dict_file(dict_file)

    # test invalid arguments
    with pytest.raises(AssertionError):
        AttnConvertor(5)
    with pytest.raises(AssertionError):
        AttnConvertor('DICT90', dict_file, '1')
    with pytest.raises(AssertionError):
        AttnConvertor('DICT90', dict_file, True, '1')

    label_convertor = AttnConvertor(dict_file=dict_file, max_seq_len=10)
    # test init and parse_dict
    assert label_convertor.num_classes() == 10
    assert len(label_convertor.idx2char) == 10
    assert label_convertor.idx2char[0] == 'h'
    assert label_convertor.idx2char[1] == 'e'
    assert label_convertor.idx2char[-3] == '<UKN>'
    assert label_convertor.char2idx['h'] == 0
    assert label_convertor.unknown_idx == 7

    # test encode str to tensor
    strings = ['hell']
    targets_dict = label_convertor.str2tensor(strings)
    assert torch.allclose(targets_dict['targets'][0],
                          torch.LongTensor([0, 1, 2, 2]))
    assert torch.allclose(targets_dict['padded_targets'][0],
                          torch.LongTensor([8, 0, 1, 2, 2, 8, 9, 9, 9, 9]))

    # test decode output to index
    dummy_output = torch.Tensor([[[100, 2, 3, 4, 5, 6, 7, 8, 9],
                                  [1, 100, 3, 4, 5, 6, 7, 8, 9],
                                  [1, 2, 100, 4, 5, 6, 7, 8, 9],
                                  [1, 2, 100, 4, 5, 6, 7, 8, 9],
                                  [1, 2, 3, 4, 5, 6, 7, 8, 100],
                                  [1, 2, 3, 4, 5, 6, 7, 100, 9],
                                  [1, 2, 3, 4, 5, 6, 7, 100, 9],
                                  [1, 2, 3, 4, 5, 6, 7, 100, 9],
                                  [1, 2, 3, 4, 5, 6, 7, 100, 9],
                                  [1, 2, 3, 4, 5, 6, 7, 100, 9]]])
    indexes, scores = label_convertor.tensor2idx(dummy_output)
    assert np.allclose(indexes, [[0, 1, 2, 2]])

    # test encode_str_label_to_index
    with pytest.raises(AssertionError):
        label_convertor.str2idx('hell')
    tmp_indexes = label_convertor.str2idx(strings)
    assert np.allclose(tmp_indexes, [[0, 1, 2, 2]])

    # test decode_index_to_str_label
    input_indexes = [[0, 1, 2, 2]]
    with pytest.raises(AssertionError):
        label_convertor.idx2str('hell')
    output_strings = label_convertor.idx2str(input_indexes)
    assert output_strings[0] == 'hell'

    tmp_dir.cleanup()

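The assertions above pin down the convertor's index layout for the 7-character dummy dictionary. Spelled out (the names of the last two symbols are inferred from the encoded targets, not quoted from the implementation):

# 10 classes for dict 'helowrd':
#   0..6 -> 'h', 'e', 'l', 'o', 'w', 'r', 'd'  (dictionary characters)
#   7    -> '<UKN>'                            (unknown_idx)
#   8    -> start/end marker, wrapping the sequence in padded_targets
#   9    -> padding symbol, filling up to max_seq_len=10
# hence 'hell' -> targets [0, 1, 2, 2]
#       and padded_targets [8, 0, 1, 2, 2, 8, 9, 9, 9, 9]
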
@ -0,0 +1,79 @@
import os.path as osp
import tempfile

import numpy as np
import pytest
import torch

from mmocr.models.textrecog.convertors import BaseConvertor, CTCConvertor


def _create_dummy_dict_file(dict_file):
    chars = list('helowrd')
    with open(dict_file, 'w') as fw:
        for char in chars:
            fw.write(char + '\n')


def test_ctc_label_convertor():
    tmp_dir = tempfile.TemporaryDirectory()
    # create dummy data
    dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
    _create_dummy_dict_file(dict_file)

    # test invalid arguments
    with pytest.raises(AssertionError):
        CTCConvertor(5)

    label_convertor = CTCConvertor(dict_file=dict_file, with_unknown=False)
    # test init and parse_chars
    assert label_convertor.num_classes() == 8
    assert len(label_convertor.idx2char) == 8
    assert label_convertor.idx2char[0] == '<BLK>'
    assert label_convertor.char2idx['h'] == 1
    assert label_convertor.unknown_idx is None

    # test encode str to tensor
    strings = ['hell']
    expect_tensor = torch.IntTensor([1, 2, 3, 3])
    targets_dict = label_convertor.str2tensor(strings)
    assert torch.allclose(targets_dict['targets'][0], expect_tensor)
    assert torch.allclose(targets_dict['flatten_targets'], expect_tensor)
    assert torch.allclose(targets_dict['target_lengths'], torch.IntTensor([4]))

    # test decode output to index
    dummy_output = torch.Tensor([[[1, 100, 3, 4, 5, 6, 7, 8],
                                  [100, 2, 3, 4, 5, 6, 7, 8],
                                  [1, 2, 100, 4, 5, 6, 7, 8],
                                  [1, 2, 100, 4, 5, 6, 7, 8],
                                  [100, 2, 3, 4, 5, 6, 7, 8],
                                  [1, 2, 3, 100, 5, 6, 7, 8],
                                  [100, 2, 3, 4, 5, 6, 7, 8],
                                  [1, 2, 3, 100, 5, 6, 7, 8]]])
    indexes, scores = label_convertor.tensor2idx(
        dummy_output, img_metas=[{
            'valid_ratio': 1.0
        }])
    assert np.allclose(indexes, [[1, 2, 3, 3]])

    # test encode_str_label_to_index
    with pytest.raises(AssertionError):
        label_convertor.str2idx('hell')
    tmp_indexes = label_convertor.str2idx(strings)
    assert np.allclose(tmp_indexes, [[1, 2, 3, 3]])

    # test decode_index_to_str_label
    input_indexes = [[1, 2, 3, 3]]
    with pytest.raises(AssertionError):
        label_convertor.idx2str('hell')
    output_strings = label_convertor.idx2str(input_indexes)
    assert output_strings[0] == 'hell'

    tmp_dir.cleanup()


def test_base_label_convertor():
    with pytest.raises(NotImplementedError):
        label_convertor = BaseConvertor()
        label_convertor.str2tensor(None)
        label_convertor.tensor2idx(None)

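The dummy output above exercises the two rules of greedy CTC decoding: collapse consecutive repeats, then drop the blank (index 0). A minimal sketch of that decode step, illustrative only — the real `tensor2idx` also returns per-character scores and honours `valid_ratio`:

import torch


def ctc_greedy_decode(logits, blank=0):
    # logits: (seq_len, num_classes) for a single sample
    best = torch.argmax(logits, dim=-1).tolist()
    decoded, prev = [], blank
    for idx in best:
        # keep an index only if it differs from its predecessor
        # (collapse repeats) and is not the blank symbol
        if idx != prev and idx != blank:
            decoded.append(idx)
        prev = idx
    return decoded


# per-step argmaxes of the dummy output: [1, 0, 2, 2, 0, 3, 0, 3]
logits = torch.zeros(8, 8)
for t, c in enumerate([1, 0, 2, 2, 0, 3, 0, 3]):
    logits[t, c] = 100.0
assert ctc_greedy_decode(logits) == [1, 2, 3, 3]
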
@ -0,0 +1,36 @@
import pytest
import torch

from mmocr.models.textrecog.backbones import ResNet31OCR, VeryDeepVgg


def test_resnet31_ocr_backbone():
    """Test resnet backbone."""
    with pytest.raises(AssertionError):
        ResNet31OCR(2.5)

    with pytest.raises(AssertionError):
        ResNet31OCR(3, layers=5)

    with pytest.raises(AssertionError):
        ResNet31OCR(3, channels=5)

    # Test ResNet31OCR forward
    model = ResNet31OCR()
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 32, 160)
    feat = model(imgs)
    assert feat.shape == torch.Size([1, 512, 4, 40])


def test_very_deep_vgg_ocr_backbone():

    model = VeryDeepVgg()
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 32, 160)
    feats = model(imgs)
    assert feats.shape == torch.Size([1, 512, 1, 41])

@ -0,0 +1,112 @@
import math

import pytest
import torch

from mmocr.models.textrecog.decoders import (BaseDecoder, ParallelSARDecoder,
                                             ParallelSARDecoderWithBS,
                                             SequentialSARDecoder, TFDecoder)
from mmocr.models.textrecog.decoders.sar_decoder_with_bs import DecodeNode


def _create_dummy_input():
    feat = torch.rand(1, 512, 4, 40)
    out_enc = torch.rand(1, 512)
    tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
    img_metas = [{'valid_ratio': 1.0}]

    return feat, out_enc, tgt_dict, img_metas


def test_base_decoder():
    decoder = BaseDecoder()
    with pytest.raises(NotImplementedError):
        decoder.forward_train(None, None, None, None)
    with pytest.raises(NotImplementedError):
        decoder.forward_test(None, None, None)


def test_parallel_sar_decoder():
    # test parallel sar decoder
    decoder = ParallelSARDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
    decoder.init_weights()
    decoder.train()

    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, [], True)
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, img_metas * 2, True)

    out_train = decoder(feat, out_enc, tgt_dict, img_metas, True)
    assert out_train.shape == torch.Size([1, 5, 36])

    out_test = decoder(feat, out_enc, tgt_dict, img_metas, False)
    assert out_test.shape == torch.Size([1, 5, 36])


def test_sequential_sar_decoder():
    # test sequential sar decoder
    decoder = SequentialSARDecoder(
        num_classes=37, padding_idx=36, max_seq_len=5)
    decoder.init_weights()
    decoder.train()

    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, [])
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, img_metas * 2)

    out_train = decoder(feat, out_enc, tgt_dict, img_metas, True)
    assert out_train.shape == torch.Size([1, 5, 36])

    out_test = decoder(feat, out_enc, tgt_dict, img_metas, False)
    assert out_test.shape == torch.Size([1, 5, 36])


def test_parallel_sar_decoder_with_beam_search():
    with pytest.raises(AssertionError):
        ParallelSARDecoderWithBS(beam_width='beam')
    with pytest.raises(AssertionError):
        ParallelSARDecoderWithBS(beam_width=0)

    feat, out_enc, tgt_dict, img_metas = _create_dummy_input()
    decoder = ParallelSARDecoderWithBS(
        beam_width=1, num_classes=37, padding_idx=36, max_seq_len=5)
    decoder.init_weights()
    decoder.train()
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, [])
    with pytest.raises(AssertionError):
        decoder(feat, out_enc, tgt_dict, img_metas * 2)

    out_test = decoder(feat, out_enc, tgt_dict, img_metas, train_mode=False)
    assert out_test.shape == torch.Size([1, 5, 36])

    # test DecodeNode
    with pytest.raises(AssertionError):
        DecodeNode(1, 1)
    with pytest.raises(AssertionError):
        DecodeNode([1, 2], ['4', '3'])
    with pytest.raises(AssertionError):
        DecodeNode([1, 2], [0.5])
    decode_node = DecodeNode([1, 2], [0.7, 0.8])
    assert math.isclose(decode_node.eval(), 1.5)


def test_transformer_decoder():
    decoder = TFDecoder(num_classes=37, padding_idx=36, max_seq_len=5)
    decoder.init_weights()
    decoder.train()

    out_enc = torch.rand(1, 128, 512)
    tgt_dict = {'padded_targets': torch.LongTensor([[1, 1, 1, 1, 36]])}
    img_metas = [{'valid_ratio': 1.0}]

    out_train = decoder(None, out_enc, tgt_dict, img_metas, True)
    assert out_train.shape == torch.Size([1, 5, 36])

    out_test = decoder(None, out_enc, tgt_dict, img_metas, False)
    assert out_test.shape == torch.Size([1, 5, 36])

@ -0,0 +1,53 @@
import pytest
import torch

from mmocr.models.textrecog.encoders import BaseEncoder, SAREncoder, TFEncoder


def test_sar_encoder():
    with pytest.raises(AssertionError):
        SAREncoder(enc_bi_rnn='bi')
    with pytest.raises(AssertionError):
        SAREncoder(enc_do_rnn=2)
    with pytest.raises(AssertionError):
        SAREncoder(enc_gru='gru')
    with pytest.raises(AssertionError):
        SAREncoder(d_model=512.5)
    with pytest.raises(AssertionError):
        SAREncoder(d_enc=200.5)
    with pytest.raises(AssertionError):
        SAREncoder(mask='mask')

    encoder = SAREncoder()
    encoder.init_weights()
    encoder.train()

    feat = torch.randn(1, 512, 4, 40)
    with pytest.raises(AssertionError):
        encoder(feat)
    img_metas = [{'valid_ratio': 1.0}]
    with pytest.raises(AssertionError):
        encoder(feat, img_metas * 2)
    out_enc = encoder(feat, img_metas)

    assert out_enc.shape == torch.Size([1, 512])


def test_transformer_encoder():
    tf_encoder = TFEncoder()
    tf_encoder.init_weights()
    tf_encoder.train()

    feat = torch.randn(1, 512, 4, 40)
    out_enc = tf_encoder(feat)
    assert out_enc.shape == torch.Size([1, 160, 512])


def test_base_encoder():
    encoder = BaseEncoder()
    encoder.init_weights()
    encoder.train()

    feat = torch.randn(1, 256, 4, 40)
    out_enc = encoder(feat)
    assert out_enc.shape == torch.Size([1, 256, 4, 40])

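The (1, 512, 4, 40) to (1, 160, 512) shape in `test_transformer_encoder` is consistent with flattening the 4x40 feature map into a length-160 token sequence. A sketch of that reshaping, as an assumption about the implementation rather than a quote of it:

import torch

feat = torch.randn(1, 512, 4, 40)       # (N, C, H, W) backbone feature
seq = feat.flatten(2).permute(0, 2, 1)  # (N, H*W, C) = (1, 160, 512)
assert seq.shape == torch.Size([1, 160, 512])
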
@ -0,0 +1,16 @@
import pytest
import torch

from mmocr.models.textrecog import SegHead


def test_seg_head():
    with pytest.raises(AssertionError):
        SegHead(num_classes='100')
    with pytest.raises(AssertionError):
        SegHead(num_classes=-1)

    seg_head = SegHead(num_classes=37)
    out_neck = (torch.rand(1, 128, 32, 32), )
    out_head = seg_head(out_neck)
    assert out_head.shape == torch.Size([1, 37, 32, 32])

@ -0,0 +1,55 @@
import torch

from mmocr.models.textrecog.layers import (BasicBlock, Bottleneck,
                                           DecoderLayer, PositionalEncoding,
                                           get_pad_mask, get_subsequent_mask)
from mmocr.models.textrecog.layers.conv_layer import conv3x3


def test_conv_layer():
    conv3by3 = conv3x3(3, 6)
    assert conv3by3.in_channels == 3
    assert conv3by3.out_channels == 6
    assert conv3by3.kernel_size == (3, 3)

    x = torch.rand(1, 64, 224, 224)
    # test basic block
    basic_block = BasicBlock(64, 64)
    assert basic_block.expansion == 1

    out = basic_block(x)

    assert out.shape == torch.Size([1, 64, 224, 224])

    # test bottleneck
    bottle_neck = Bottleneck(64, 64, downsample=True)
    assert bottle_neck.expansion == 4

    out = bottle_neck(x)

    assert out.shape == torch.Size([1, 256, 224, 224])


def test_transformer_layer():
    # test decoder_layer
    decoder_layer = DecoderLayer()
    in_dec = torch.rand(1, 30, 512)
    out_enc = torch.rand(1, 128, 512)
    out_dec = decoder_layer(in_dec, out_enc)
    assert out_dec.shape == torch.Size([1, 30, 512])

    # test positional_encoding
    pos_encoder = PositionalEncoding()
    x = torch.rand(1, 30, 512)
    out = pos_encoder(x)
    assert out.size() == x.size()

    # test get_pad_mask
    seq = torch.rand(1, 30)
    pad_idx = 0
    out = get_pad_mask(seq, pad_idx)
    assert out.shape == torch.Size([1, 1, 30])

    # test get_subsequent_mask
    out_mask = get_subsequent_mask(seq)
    assert out_mask.shape == torch.Size([1, 30, 30])

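For reference, the standard way such a subsequent (causal) mask is built, matching the asserted (1, 30, 30) shape; this is a generic sketch, not necessarily the library's exact code:

import torch


def subsequent_mask(seq):
    # True where position i may attend to position j (j <= i)
    batch, length = seq.shape
    mask = torch.tril(torch.ones(length, length, dtype=torch.bool))
    return mask.unsqueeze(0).expand(batch, -1, -1)  # (batch, len, len)


assert subsequent_mask(torch.rand(1, 30)).shape == torch.Size([1, 30, 30])
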
@ -0,0 +1,123 @@
import numpy as np
import pytest
import torch

from mmdet.core import BitmapMasks
from mmocr.models.common.losses import DiceLoss
from mmocr.models.textrecog.losses import (CAFCNLoss, CELoss, CTCLoss, SARLoss,
                                           TFLoss)


def test_ctc_loss():
    # test CTCLoss
    ctc_loss = CTCLoss()
    outputs = torch.zeros(2, 40, 37)
    targets_dict = {
        'flatten_targets': torch.IntTensor([1, 2, 3, 4, 5]),
        'target_lengths': torch.LongTensor([2, 3])
    }

    losses = ctc_loss(outputs, targets_dict)
    assert isinstance(losses, dict)
    assert 'loss_ctc' in losses
    assert torch.allclose(losses['loss_ctc'],
                          torch.tensor(losses['loss_ctc'].item()).float())


def test_ce_loss():
    with pytest.raises(AssertionError):
        CELoss(ignore_index='ignore')
    with pytest.raises(AssertionError):
        CELoss(reduction=1)
    with pytest.raises(AssertionError):
        CELoss(reduction='avg')

    ce_loss = CELoss(ignore_index=0)
    outputs = torch.rand(1, 10, 37)
    targets_dict = {
        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
    }
    losses = ce_loss(outputs, targets_dict)
    assert isinstance(losses, dict)
    assert 'loss_ce' in losses
    assert losses['loss_ce'].size(1) == 10


def test_sar_loss():
    outputs = torch.rand(1, 10, 37)
    targets_dict = {
        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
    }
    sar_loss = SARLoss()
    new_output, new_target = sar_loss.format(outputs, targets_dict)
    assert new_output.shape == torch.Size([1, 37, 9])
    assert new_target.shape == torch.Size([1, 9])


def test_tf_loss():
    with pytest.raises(AssertionError):
        TFLoss(flatten=1.0)

    outputs = torch.rand(1, 10, 37)
    targets_dict = {
        'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
    }
    tf_loss = TFLoss(flatten=False)
    new_output, new_target = tf_loss.format(outputs, targets_dict)
    assert new_output.shape == torch.Size([1, 37, 9])
    assert new_target.shape == torch.Size([1, 9])


def test_cafcn_loss():
    with pytest.raises(AssertionError):
        CAFCNLoss(alpha='1')
    with pytest.raises(AssertionError):
        CAFCNLoss(attn_s2_downsample_ratio='2')
    with pytest.raises(AssertionError):
        CAFCNLoss(attn_s3_downsample_ratio='1.5')
    with pytest.raises(AssertionError):
        CAFCNLoss(seg_downsample_ratio='1.5')
    with pytest.raises(AssertionError):
        CAFCNLoss(attn_s2_downsample_ratio=2)
    with pytest.raises(AssertionError):
        CAFCNLoss(attn_s3_downsample_ratio=1.5)
    with pytest.raises(AssertionError):
        CAFCNLoss(seg_downsample_ratio=1.5)

    bsz = 1
    H = W = 64
    out_neck = (torch.ones(bsz, 1, H // 4, W // 4) * 0.5,
                torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
                torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
                torch.ones(bsz, 1, H // 8, W // 8) * 0.5,
                torch.ones(bsz, 1, H // 2, W // 2) * 0.5)
    out_head = torch.rand(bsz, 37, H // 2, W // 2)

    attn_tgt = np.zeros((H, W), dtype=np.float32)
    segm_tgt = np.zeros((H, W), dtype=np.float32)
    mask = np.ones((H, W), dtype=np.float32)
    gt_kernels = BitmapMasks([attn_tgt, segm_tgt, mask], H, W)

    cafcn_loss = CAFCNLoss()
    losses = cafcn_loss(out_neck, out_head, [gt_kernels])
    assert isinstance(losses, dict)
    assert 'loss_seg' in losses
    assert torch.allclose(losses['loss_seg'],
                          torch.tensor(losses['loss_seg'].item()).float())


def test_dice_loss():
    with pytest.raises(AssertionError):
        DiceLoss(eps='1')

    dice_loss = DiceLoss()
    pred = torch.rand(1, 1, 32, 32)
    gt = torch.rand(1, 1, 32, 32)

    loss = dice_loss(pred, gt, None)
    assert isinstance(loss, torch.Tensor)

    mask = torch.rand(1, 1, 1, 1)
    loss = dice_loss(pred, gt, mask)
    assert isinstance(loss, torch.Tensor)

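The (1, 10, 37) to output (1, 37, 9) / target (1, 9) shapes in the SAR and TF loss tests are consistent with a one-step teacher-forcing shift plus the (N, C, L) layout that `torch.nn.functional.cross_entropy` expects for sequences. A sketch of such a `format` step, as an assumption about what the tested method does (the helper name is made up):

import torch


def format_for_ce(outputs, padded_targets):
    # outputs: (N, seq_len, num_classes); padded_targets: (N, seq_len)
    new_target = padded_targets[:, 1:].contiguous()   # predict the next token
    new_output = outputs[:, :-1, :].permute(0, 2, 1)  # (N, C, seq_len - 1)
    return new_output, new_target


out, tgt = format_for_ce(torch.rand(1, 10, 37),
                         torch.zeros(1, 10, dtype=torch.long))
assert out.shape == torch.Size([1, 37, 9]) and tgt.shape == torch.Size([1, 9])
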
@ -0,0 +1,48 @@
import pytest
import torch

from mmocr.models.textrecog.necks.cafcn_neck import (CAFCNNeck, CharAttn,
                                                     FeatGenerator)


def test_char_attn():
    with pytest.raises(AssertionError):
        CharAttn(in_channels=5.0)
    with pytest.raises(AssertionError):
        CharAttn(deformable='deformable')

    in_feat = torch.rand(1, 128, 32, 32)
    char_attn = CharAttn()
    out_feat_map, attn_map = char_attn(in_feat)
    assert attn_map.shape == torch.Size([1, 1, 32, 32])
    assert out_feat_map.shape == torch.Size([1, 128, 32, 32])


@pytest.mark.skip(reason='TODO: re-enable after CI supports pytorch>1.4')
def test_feat_generator():
    in_feat = torch.rand(1, 128, 32, 32)
    feat_generator = FeatGenerator(in_channels=128, out_channels=128)

    attn_map, feat_map = feat_generator(in_feat)
    assert attn_map.shape == torch.Size([1, 1, 32, 32])
    assert feat_map.shape == torch.Size([1, 128, 32, 32])


@pytest.mark.skip(reason='TODO: re-enable after CI supports pytorch>1.4')
def test_cafcn_neck():
    in_s1 = torch.rand(1, 64, 64, 64)
    in_s2 = torch.rand(1, 128, 32, 32)
    in_s3 = torch.rand(1, 256, 16, 16)
    in_s4 = torch.rand(1, 512, 16, 16)
    in_s5 = torch.rand(1, 512, 16, 16)

    cafcn_neck = CAFCNNeck()
    cafcn_neck.init_weights()
    cafcn_neck.train()

    out_neck = cafcn_neck((in_s1, in_s2, in_s3, in_s4, in_s5))
    assert out_neck[0].shape == torch.Size([1, 1, 32, 32])
    assert out_neck[1].shape == torch.Size([1, 1, 16, 16])
    assert out_neck[2].shape == torch.Size([1, 1, 16, 16])
    assert out_neck[3].shape == torch.Size([1, 1, 16, 16])
    assert out_neck[4].shape == torch.Size([1, 128, 64, 64])

@ -0,0 +1,156 @@
import os.path as osp
import tempfile

import numpy as np
import pytest
import torch

from mmdet.core import BitmapMasks
from mmocr.models.textrecog.recognizer import (EncodeDecodeRecognizer,
                                               SegRecognizer)


def _create_dummy_dict_file(dict_file):
    chars = list('helowrd')
    with open(dict_file, 'w') as fw:
        for char in chars:
            fw.write(char + '\n')


def test_base_recognizer():
    tmp_dir = tempfile.TemporaryDirectory()
    # create dummy data
    dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
    _create_dummy_dict_file(dict_file)

    label_convertor = dict(
        type='CTCConvertor', dict_file=dict_file, with_unknown=False)

    preprocessor = None
    backbone = dict(type='VeryDeepVgg', leakyRelu=False)
    encoder = None
    decoder = dict(type='CRNNDecoder', in_channels=512, rnn_flag=True)
    loss = dict(type='CTCLoss')

    with pytest.raises(AssertionError):
        EncodeDecodeRecognizer(backbone=None)
    with pytest.raises(AssertionError):
        EncodeDecodeRecognizer(decoder=None)
    with pytest.raises(AssertionError):
        EncodeDecodeRecognizer(loss=None)
    with pytest.raises(AssertionError):
        EncodeDecodeRecognizer(label_convertor=None)

    recognizer = EncodeDecodeRecognizer(
        preprocessor=preprocessor,
        backbone=backbone,
        encoder=encoder,
        decoder=decoder,
        loss=loss,
        label_convertor=label_convertor)

    recognizer.init_weights()
    recognizer.train()

    imgs = torch.rand(1, 3, 32, 160)

    # test extract_feat
    feat = recognizer.extract_feat(imgs)
    assert feat.shape == torch.Size([1, 512, 1, 41])

    # test forward_train
    img_metas = [{'text': 'hello', 'valid_ratio': 1.0}]
    losses = recognizer.forward_train(imgs, img_metas)
    assert isinstance(losses, dict)
    assert 'loss_ctc' in losses

    # test simple_test
    results = recognizer.simple_test(imgs, img_metas)
    assert isinstance(results, list)
    assert isinstance(results[0], dict)
    assert 'text' in results[0]
    assert 'score' in results[0]

    # test aug_test
    aug_results = recognizer.aug_test([imgs, imgs], [img_metas, img_metas])
    assert isinstance(aug_results, list)
    assert isinstance(aug_results[0], dict)
    assert 'text' in aug_results[0]
    assert 'score' in aug_results[0]

    tmp_dir.cleanup()


@pytest.mark.skip(reason='TODO: re-enable after CI supports pytorch>1.4')
def test_seg_recognizer():
    tmp_dir = tempfile.TemporaryDirectory()
    # create dummy data
    dict_file = osp.join(tmp_dir.name, 'fake_chars.txt')
    _create_dummy_dict_file(dict_file)

    label_convertor = dict(
        type='SegConvertor', dict_file=dict_file, with_unknown=False)

    preprocessor = None
    backbone = dict(type='ResNet31OCR')
    neck = dict(type='FPNOCR')
    head = dict(type='SegHead')
    loss = dict(type='SegLoss')

    with pytest.raises(AssertionError):
        SegRecognizer(backbone=None)
    with pytest.raises(AssertionError):
        SegRecognizer(neck=None)
    with pytest.raises(AssertionError):
        SegRecognizer(head=None)
    with pytest.raises(AssertionError):
        SegRecognizer(loss=None)
    with pytest.raises(AssertionError):
        SegRecognizer(label_convertor=None)

    recognizer = SegRecognizer(
        preprocessor=preprocessor,
        backbone=backbone,
        neck=neck,
        head=head,
        loss=loss,
        label_convertor=label_convertor)

    recognizer.init_weights()
    recognizer.train()

    imgs = torch.rand(1, 3, 64, 256)

    # test extract_feat
    feats = recognizer.extract_feat(imgs)
    assert len(feats) == 5
    assert feats[0].shape == torch.Size([1, 64, 32, 128])
    assert feats[1].shape == torch.Size([1, 128, 16, 64])
    assert feats[2].shape == torch.Size([1, 256, 8, 32])
    assert feats[3].shape == torch.Size([1, 512, 8, 32])
    assert feats[4].shape == torch.Size([1, 512, 8, 32])

    attn_tgt = np.zeros((64, 256), dtype=np.float32)
    segm_tgt = np.zeros((64, 256), dtype=np.float32)
    gt_kernels = BitmapMasks([attn_tgt, segm_tgt], 64, 256)

    # test forward_train
    img_metas = [{'text': 'hello', 'valid_ratio': 1.0}]
    losses = recognizer.forward_train(imgs, img_metas, gt_kernels=[gt_kernels])
    assert isinstance(losses, dict)

    # test simple_test
    results = recognizer.simple_test(imgs, img_metas)
    assert isinstance(results, list)
    assert isinstance(results[0], dict)
    assert 'text' in results[0]
    assert 'score' in results[0]

    # test aug_test
    aug_results = recognizer.aug_test([imgs, imgs], [img_metas, img_metas])
    assert isinstance(aug_results, list)
    assert isinstance(aug_results[0], dict)
    assert 'text' in aug_results[0]
    assert 'score' in aug_results[0]

    tmp_dir.cleanup()

@ -0,0 +1,66 @@
"""Test text label visualize."""
import os.path as osp
import random
import tempfile
from unittest import mock

import numpy as np
import pytest

import mmocr.core.visualize as visualize_utils


def test_tile_image():
    dummy_imgs, heights, widths = [], [], []
    for _ in range(3):
        h = random.randint(100, 300)
        w = random.randint(100, 300)
        heights.append(h)
        widths.append(w)
        dummy_img = np.ones((h, w, 3), dtype=np.uint8)
        dummy_imgs.append(dummy_img)
    joint_img = visualize_utils.tile_image(dummy_imgs)
    assert joint_img.shape[0] == sum(heights)
    assert joint_img.shape[1] == max(widths)

    # test invalid arguments
    with pytest.raises(AssertionError):
        visualize_utils.tile_image(dummy_imgs[0])
    with pytest.raises(AssertionError):
        visualize_utils.tile_image([])


@mock.patch('%s.visualize_utils.mmcv.imread' % __name__)
@mock.patch('%s.visualize_utils.mmcv.imshow' % __name__)
@mock.patch('%s.visualize_utils.mmcv.imwrite' % __name__)
def test_show_text_label(mock_imwrite, mock_imshow, mock_imread):
    img = np.ones((32, 160), dtype=np.uint8)
    pred_label = 'hello'
    gt_label = 'world'

    tmp_dir = tempfile.TemporaryDirectory()
    out_file = osp.join(tmp_dir.name, 'tmp.jpg')

    # test invalid arguments
    with pytest.raises(AssertionError):
        visualize_utils.imshow_text_label(5, pred_label, gt_label)
    with pytest.raises(AssertionError):
        visualize_utils.imshow_text_label(img, pred_label, 4)
    with pytest.raises(AssertionError):
        visualize_utils.imshow_text_label(img, 3, gt_label)
    with pytest.raises(AssertionError):
        visualize_utils.imshow_text_label(
            img, pred_label, gt_label, show=True, wait_time=0.1)

    mock_imread.side_effect = [img, img]
    visualize_utils.imshow_text_label(
        img, pred_label, gt_label, out_file=out_file)
    visualize_utils.imshow_text_label(
        img, pred_label, gt_label, out_file=None, show=True)

    # check that the image was shown and written exactly once
    mock_imshow.assert_called_once()
    mock_imwrite.assert_called_once()

    tmp_dir.cleanup()

@ -0,0 +1,93 @@
import argparse
import codecs
import json
import os.path as osp

import cv2


def read_json(fpath):
    with codecs.open(fpath, 'r', 'utf-8') as f:
        obj = json.load(f)
    return obj


def parse_old_label(img_prefix, in_path):
    imgid2imgname = {}
    imgid2anno = {}
    idx = 0
    with open(in_path, 'r') as fr:
        for line in fr:
            line = line.strip().split()
            img_full_path = osp.join(img_prefix, line[0])
            if not osp.exists(img_full_path):
                continue
            img = cv2.imread(img_full_path)
            h, w = img.shape[:2]
            img_info = {}
            img_info['file_name'] = line[0]
            img_info['height'] = h
            img_info['width'] = w
            imgid2imgname[idx] = img_info
            imgid2anno[idx] = []
            for i in range(len(line[1:]) // 8):
                seg = [int(x) for x in line[(1 + i * 8):(1 + (i + 1) * 8)]]
                # x coordinates sit at even offsets, y at odd offsets
                points_x = seg[0:8:2]
                points_y = seg[1:8:2]
                box = [
                    min(points_x),
                    min(points_y),
                    max(points_x),
                    max(points_y)
                ]
                new_anno = {}
                new_anno['iscrowd'] = 0
                new_anno['category_id'] = 1
                new_anno['bbox'] = box
                new_anno['segmentation'] = [seg]
                imgid2anno[idx].append(new_anno)
            idx += 1

    return imgid2imgname, imgid2anno


def gen_line_dict_file(out_path, imgid2imgname, imgid2anno):
    with codecs.open(out_path, 'w', 'utf-8') as fw:
        for key, value in imgid2imgname.items():
            if key in imgid2anno:
                anno = imgid2anno[key]
                line_dict = {}
                line_dict['file_name'] = value['file_name']
                line_dict['height'] = value['height']
                line_dict['width'] = value['width']
                line_dict['annotations'] = anno
                line_dict_str = json.dumps(line_dict)
                fw.write(line_dict_str + '\n')


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--img-prefix',
        help='image prefix, to generate full image path with "image_name"')
    parser.add_argument(
        '--in-path',
        help='mapping file of image_name and ann_file,'
        ' "image_name ann_file" in each line')
    parser.add_argument(
        '--out-path', help='output txt path with line-json format')

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    imgid2imgname, imgid2anno = parse_old_label(args.img_prefix, args.in_path)
    gen_line_dict_file(args.out_path, imgid2imgname, imgid2anno)
    print('finish')


if __name__ == '__main__':
    main()

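For reference, one output line produced by `gen_line_dict_file` would look like the following (file name, size, and coordinates are made-up example values):

{"file_name": "demo.jpg", "height": 100, "width": 200, "annotations": [{"iscrowd": 0, "category_id": 1, "bbox": [20, 20, 40, 40], "segmentation": [[20, 20, 40, 20, 40, 40, 20, 40]]}]}
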
@ -0,0 +1,74 @@
import argparse
import shutil
import sys
import time
from pathlib import Path

import lmdb


def converter(imglist, output, batch_size=1000, coding='utf-8'):
    # read imglist
    with open(imglist) as f:
        lines = f.readlines()

    # create lmdb database
    if Path(output).is_dir():
        while True:
            print('%s already exists, delete or not? [Y/n]' % output)
            Yn = input().strip()
            if Yn in ['Y', 'y']:
                shutil.rmtree(output)
                break
            elif Yn in ['N', 'n']:
                return
    print('create database %s' % output)
    Path(output).mkdir(parents=True, exist_ok=False)
    env = lmdb.open(output, map_size=1099511627776)

    # build lmdb
    beg_time = time.strftime('%H:%M:%S')
    for beg_index in range(0, len(lines), batch_size):
        end_index = min(beg_index + batch_size, len(lines))
        sys.stdout.write('\r[%s-%s], processing [%d-%d] / %d' %
                         (beg_time, time.strftime('%H:%M:%S'), beg_index,
                          end_index, len(lines)))
        sys.stdout.flush()
        batch = [(str(index).encode(coding), lines[index].encode(coding))
                 for index in range(beg_index, end_index)]
        with env.begin(write=True) as txn:
            cursor = txn.cursor()
            cursor.putmulti(batch, dupdata=False, overwrite=True)
    sys.stdout.write('\n')
    with env.begin(write=True) as txn:
        key = 'total_number'.encode(coding)
        value = str(len(lines)).encode(coding)
        txn.put(key, value)
    print('done', flush=True)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--imglist', '-i', required=True, help='input imglist path')
    parser.add_argument(
        '--output', '-o', required=True, help='output lmdb path')
    parser.add_argument(
        '--batch_size',
        '-b',
        type=int,
        default=10000,
        help='processing batch size, default 10000')
    parser.add_argument(
        '--coding',
        '-c',
        default='utf8',
        help='bytes coding scheme, default utf8')
    opt = parser.parse_args()

    converter(
        opt.imglist, opt.output, batch_size=opt.batch_size, coding=opt.coding)


if __name__ == '__main__':
    main()

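A minimal sketch of reading the database back, mirroring the key scheme the converter writes (stringified line indices plus a 'total_number' entry). This is illustrative usage only; the path is hypothetical:

import lmdb

env = lmdb.open('path/to/output', readonly=True, lock=False)  # hypothetical path
with env.begin() as txn:
    total = int(txn.get('total_number'.encode('utf-8')))
    first_line = txn.get('0'.encode('utf-8')).decode('utf-8')
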
@ -0,0 +1,132 @@
import os.path as osp
import shutil
import time
from argparse import ArgumentParser

import mmcv
import torch
from mmcv.utils import ProgressBar

from mmdet.apis import init_detector
from mmdet.utils import get_root_logger
from mmocr.apis import model_inference
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
from mmocr.datasets import build_dataset  # noqa: F401
from mmocr.models import build_detector  # noqa: F401


def save_results(img_paths, pred_labels, gt_labels, res_dir):
    """Save predicted results to txt file.

    Args:
        img_paths (list[str])
        pred_labels (list[str])
        gt_labels (list[str])
        res_dir (str)
    """
    assert len(img_paths) == len(pred_labels) == len(gt_labels)
    res_file = osp.join(res_dir, 'results.txt')
    correct_file = osp.join(res_dir, 'correct.txt')
    wrong_file = osp.join(res_dir, 'wrong.txt')
    with open(res_file, 'w') as fw, \
            open(correct_file, 'w') as fw_correct, \
            open(wrong_file, 'w') as fw_wrong:
        for img_path, pred_label, gt_label in zip(img_paths, pred_labels,
                                                  gt_labels):
            fw.write(img_path + ' ' + pred_label + ' ' + gt_label + '\n')
            if pred_label == gt_label:
                fw_correct.write(img_path + ' ' + pred_label + ' ' +
                                 gt_label + '\n')
            else:
                fw_wrong.write(img_path + ' ' + pred_label + ' ' + gt_label +
                               '\n')


def main():
    parser = ArgumentParser()
    parser.add_argument('--img_root_path', type=str, help='Image root path')
    parser.add_argument('--img_list', type=str, help='Image path list file')
    parser.add_argument('--config', type=str, help='Config file')
    parser.add_argument('--checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--out_dir', type=str, default='./results', help='Dir to save results')
    parser.add_argument(
        '--show', action='store_true', help='show image or save')
    args = parser.parse_args()

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(args.out_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level='INFO')

    # build the model from a config file and a checkpoint file
    device = 'cuda:' + str(torch.cuda.current_device())
    model = init_detector(args.config, args.checkpoint, device=device)
    if hasattr(model, 'module'):
        model = model.module
    if model.cfg.data.test['type'] == 'ConcatDataset':
        model.cfg.data.test.pipeline = \
            model.cfg.data.test['datasets'][0].pipeline

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    correct_vis_dir = osp.join(args.out_dir, 'correct')
    mmcv.mkdir_or_exist(correct_vis_dir)
    wrong_vis_dir = osp.join(args.out_dir, 'wrong')
    mmcv.mkdir_or_exist(wrong_vis_dir)
    img_paths, pred_labels, gt_labels = [], [], []
    total_img_num = sum(1 for _ in open(args.img_list))
    progressbar = ProgressBar(task_num=total_img_num)
    num_gt_label = 0
    with open(args.img_list, 'r') as fr:
        for line in fr:
            progressbar.update()
            item_list = line.strip().split()
            img_file = item_list[0]
            gt_label = ''
            if len(item_list) >= 2:
                gt_label = item_list[1]
                num_gt_label += 1
            img_path = osp.join(args.img_root_path, img_file)
            if not osp.exists(img_path):
                raise FileNotFoundError(img_path)
            # Test a single image
            result = model_inference(model, img_path)
            pred_label = result['text']

            out_img_name = '_'.join(img_file.split('/'))
            out_file = osp.join(out_vis_dir, out_img_name)
            kwargs_dict = {
                'gt_label': gt_label,
                'show': args.show,
                'out_file': '' if args.show else out_file
            }
            model.show_result(img_path, result, **kwargs_dict)
            if gt_label != '':
                if gt_label == pred_label:
                    dst_file = osp.join(correct_vis_dir, out_img_name)
                else:
                    dst_file = osp.join(wrong_vis_dir, out_img_name)
                shutil.copy(out_file, dst_file)
            img_paths.append(img_path)
            gt_labels.append(gt_label)
            pred_labels.append(pred_label)

    # Save results
    save_results(img_paths, pred_labels, gt_labels, args.out_dir)

    if num_gt_label == len(pred_labels):
        # evaluate only when every image carries a ground-truth label
        eval_results = eval_ocr_metric(pred_labels, gt_labels)
        logger.info('\n' + '-' * 100)
        info = 'eval on testset with img_root_path ' + \
            f'{args.img_root_path} and img_list {args.img_list}\n'
        logger.info(info)
        logger.info(eval_results)

    print(f'\nInference done, and results saved in {args.out_dir}\n')


if __name__ == '__main__':
    main()

@ -0,0 +1,25 @@
#!/bin/bash

DATE=`date +%Y-%m-%d`
TIME=`date +"%H-%M-%S"`

if [ $# -lt 5 ]
then
    echo "Usage: bash $0 CONFIG CHECKPOINT IMG_PREFIX IMG_LIST RESULTS_DIR"
    exit 1
fi

CONFIG_FILE=$1
CHECKPOINT=$2
IMG_ROOT_PATH=$3
IMG_LIST=$4
OUT_DIR=$5_${DATE}_${TIME}

mkdir -p ${OUT_DIR} &&

python tools/ocr_test_imgs.py \
    --img_root_path ${IMG_ROOT_PATH} \
    --img_list ${IMG_LIST} \
    --config ${CONFIG_FILE} \
    --checkpoint ${CHECKPOINT} \
    --out_dir ${OUT_DIR}