mirror of https://github.com/open-mmlab/mmocr.git
[SPTS] train (#1756)
* [Feature] Add RepeatAugSampler
* initial commit
* spts inference done
* merge repeat_aug (bug in multi-node?)
* fix inference
* train done
* rm readme
* Revert "merge repeat_aug (bug in multi-node?)". This reverts commit 393506a97c.
* Revert "[Feature] Add RepeatAugSampler". This reverts commit 2089b02b48.
* remove utils
* readme & conversion script
* update readme
* fix
* optimize
* rename cfg & del compose
* fix
* fix
* tmp commit
* update training setting
* update cfg
* update readme
* e2e metric
* update cfg
* fix
* update readme
* fix
* update

parent 81fd74c266
commit 5670695338
@@ -152,8 +152,8 @@ def crop_polygon(polygon: ArrayLike,
         np.array or None: Cropped polygon. If the polygon is not within the
             crop box, return None.
     """
-    poly = poly2shapely(polygon)
-    crop_poly = poly2shapely(bbox2poly(crop_box))
+    poly = poly_make_valid(poly2shapely(polygon))
+    crop_poly = poly_make_valid(poly2shapely(bbox2poly(crop_box)))
     poly_cropped = poly.intersection(crop_poly)
     if poly_cropped.area == 0. or not isinstance(
             poly_cropped, shapely.geometry.polygon.Polygon):
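The two changed lines above wrap the converted polygons with `poly_make_valid` before intersecting them. As a rough illustration of the failure mode this guards against, using plain shapely rather than the project's own helper:

```python
# Illustrative sketch only: intersecting an invalid (self-intersecting)
# polygon can raise a GEOS error or give a misleading result, so the
# geometry is repaired first; poly_make_valid plays that role in MMOCR.
from shapely.geometry import Polygon
from shapely.validation import make_valid

bowtie = Polygon([(0, 0), (2, 2), (2, 0), (0, 2)])  # self-intersecting
crop = Polygon([(0, 0), (2, 0), (2, 2), (0, 2)])

print(bowtie.is_valid)                # False
fixed = make_valid(bowtie)            # repaired geometry
print(fixed.intersection(crop).area)  # now a well-defined intersection
```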
@@ -36,10 +36,24 @@ $env:PYTHONPATH=Get-Location

 ### Dataset

-**As of now, the implementation uses datasets provided by SPTS, but these datasets
-will be available in MMOCR's dataset preparer very soon.**
+As of now, the implementation uses datasets provided by SPTS for pre-training, and uses MMOCR's datasets for fine-tuning and testing. It's because the test split of SPTS's datasets does not contain enough information for e2e evaluation; and MMOCR's dataset preparer has not yet supported all the datasets used in SPTS. *We are working on this issue, and they will be available in MMOCR's dataset preparer very soon.*

-Download and extract all the datasets into `data/` following [SPTS official guide](https://github.com/shannanyinxiang/SPTS#dataset).
+Please follow these steps to prepare the datasets:
+
+- Download and extract all the SPTS datasets into `spts-data/` following [SPTS official guide](https://github.com/shannanyinxiang/SPTS#dataset).
+
+- Use [Dataset Preparer](https://mmocr.readthedocs.io/en/dev-1.x/user_guides/data_prepare/dataset_preparer.html) to prepare `icdar2013`, `icdar2015` and `totaltext` for `textspotting` task.
+
+```shell
+# Run in MMOCR's root directory
+python tools/dataset_converters/prepare_dataset.py icdar2013 icdar2015 totaltext --task textspotting
+```
+
+Then create a soft link to `data/` directory in the project root directory:
+
+```shell
+ln -s ../../data/ .
+```

 ### Training commands

@@ -48,13 +62,13 @@ In the current directory, run the following command to train the model:

 #### Pretrain

 ```bash
-mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/
+mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/ --amp
 ```

 To train on multiple GPUs, e.g. 8 GPUs, run the following command:

 ```bash
-mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/ --launcher pytorch --gpus 8
+mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/ --launcher pytorch --gpus 8 --amp
 ```

 #### Finetune

@@ -62,13 +76,13 @@ mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_

 Similarly, run the following command to finetune the model on a dataset (e.g. icdar2013):

 ```bash
-mim train mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work_dirs/ --cfg-options "load_from={CHECKPOINT_PATH}"
+mim train mmocr config/spts/spts_resnet50_8xb8-200e_icdar2013.py --work-dir work_dirs/ --cfg-options "load_from={CHECKPOINT_PATH}" --amp
 ```

 To finetune on multiple GPUs, e.g. 8 GPUs, run the following command:

 ```bash
-mim train mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work_dirs/ --launcher pytorch --gpus 8 --cfg-options "load_from={CHECKPOINT_PATH}"
+mim train mmocr config/spts/spts_resnet50_8xb8-200e_icdar2013.py --work-dir work_dirs/ --launcher pytorch --gpus 8 --cfg-options "load_from={CHECKPOINT_PATH}" --amp
 ```

 ### Testing commands
@@ -76,24 +90,29 @@ mim train mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work

 In the current directory, run the following command to test the model on a dataset (e.g. icdar2013):

 ```bash
-mim test mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work_dirs/ --checkpoint ${CHECKPOINT_PATH}
+mim test mmocr config/spts/spts_resnet50_8xb8-200e_icdar2013.py --work-dir work_dirs/ --checkpoint ${CHECKPOINT_PATH}
 ```

-## Results
+## Convert Weights from Official Repo

-The weights from MMOCR are on the way. Users may download the weights from [SPTS](https://github.com/shannanyinxiang/SPTS#inference) and use the conversion script to convert them into MMOCR format.
+Users may download the weights from [SPTS](https://github.com/shannanyinxiang/SPTS#inference) and use the conversion script to convert them into MMOCR format.

 ```bash
 python tools/ckpt_adapter.py [SPTS_WEIGHTS_PATH] [MMOCR_WEIGHTS_PATH]
 ```
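For readers curious what such an adapter does, the general idea is to load the official checkpoint, rename its parameter keys to the names the MMOCR implementation expects, and re-save it. The snippet below is a hypothetical sketch of that pattern only; the renaming rule shown is invented for illustration and is not the actual logic of `tools/ckpt_adapter.py`:

```python
import sys

import torch


def convert(src_path: str, dst_path: str) -> None:
    """Hypothetical key-renaming sketch; the real mapping rules differ."""
    ckpt = torch.load(src_path, map_location='cpu')
    src_state = ckpt.get('model', ckpt)  # some repos nest weights under 'model'
    dst_state = {k.replace('module.', ''): v for k, v in src_state.items()}
    torch.save({'state_dict': dst_state}, dst_path)


if __name__ == '__main__':
    convert(sys.argv[1], sys.argv[2])
```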

-Here are the results obtained on the converted weights. The results are lower than the original ones due to the difference in the test split of datasets, which will be addressed in next update.
+## Results

-| Name       | Model                 | E2E-None-Hmean |
-| :--------: | :-------------------: | :------------: |
-| ICDAR 2013 | ic13.pth (converted)  | 0.8573         |
-| ctw1500    | ctw1500 (converted)   | 0.6304         |
-| totaltext  | totaltext (converted) | 0.6596         |
+All the models are trained on 8x A100 GPUs with AMP on (`--amp`). The overall batch size is 64.
+
+| Name | Pretrained | Generic | Weak | Strong | Download |
+| ---- | ---------- | ------- | ---- | ------ | -------- |
+| ICDAR 2013 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/spts_resnet50_150e_pretrain-spts-c9fe4c78.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/20230223_194550.log) | 87.10 | 91.46 | 93.41 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2013/spts_resnet50_200e_icdar2013-64cb4d31.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2013/20230303_140316.log) |
+| ICDAR 2015 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/spts_resnet50_150e_pretrain-spts-c9fe4c78.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/20230223_194550.log) | 69.09 | 73.45 | 79.19 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2015/spts_resnet50_200e_icdar2015-d6e8621c.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2015/20230302_230026.log) |
+
+| Name | Pretrained | None-Hmean | Full-Hmean | Download |
+| :-------: | ---------- | :--------: | :--------: | -------- |
+| Totaltext | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/spts_resnet50_150e_pretrain-spts-c9fe4c78.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/20230223_194550.log) | 73.99 | 82.34 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_totaltext/spts_resnet50_200e_totaltext-e3521af6.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_totaltext/20230303_103040.log) |

 ## Citation
@@ -136,15 +155,15 @@ A project does not necessarily have to be finished in a single PR, but it's esse

 <!-- As this template does. -->

-- [ ] Milestone 2: Indicates a successful model implementation.
+- [x] Milestone 2: Indicates a successful model implementation.

-- [ ] Training-time correctness
+- [x] Training-time correctness

 <!-- If you are reproducing the result from a paper, checking this item means that you should have trained your model from scratch based on the original paper's specification and verified that the final result matches the report within a minor error range. -->

-- [ ] Milestone 3: Good to be a part of our core package!
+- [x] Milestone 3: Good to be a part of our core package!

-- [ ] Type hints and docstrings
+- [x] Type hints and docstrings

 <!-- Ideally *all* the methods should have [type hints](https://www.pythontutorial.net/python-basics/python-type-hints/) and [docstrings](https://google.github.io/styleguide/pyguide.html#381-docstrings). [Example](https://github.com/open-mmlab/mmocr/blob/76637a290507f151215d299707c57cea5120976e/mmocr/utils/polygon_utils.py#L80-L96) -->
@@ -1,4 +1,4 @@
-icdar2013_textspotting_data_root = 'data/icdar2013'
+icdar2013_textspotting_data_root = 'spts-data/icdar2013'

 icdar2013_textspotting_train = dict(
     type='AdelDataset',

@@ -1,4 +1,4 @@
-icdar2015_textspotting_data_root = 'data/icdar2015'
+icdar2015_textspotting_data_root = 'spts-data/icdar2015'

 icdar2015_textspotting_train = dict(
     type='AdelDataset',

@@ -11,5 +11,4 @@ icdar2015_textspotting_test = dict(
     data_root=icdar2015_textspotting_data_root,
     ann_file='textspotting_test.json',
     test_mode=True,
-    # indices=50,
     pipeline=None)

@@ -1,4 +1,4 @@
-mlt_textspotting_data_root = 'data/mlt2017'
+mlt_textspotting_data_root = 'spts-data/mlt2017'

 mlt_textspotting_train = dict(
     type='AdelDataset',

@@ -1,4 +1,4 @@
-syntext1_textspotting_data_root = 'data/syntext1'
+syntext1_textspotting_data_root = 'spts-data/syntext1'

 syntext1_textspotting_train = dict(
     type='AdelDataset',

@@ -1,4 +1,4 @@
-syntext2_textspotting_data_root = 'data/syntext2'
+syntext2_textspotting_data_root = 'spts-data/syntext2'

 syntext2_textspotting_train = dict(
     type='AdelDataset',

@@ -1,4 +1,4 @@
-totaltext_textspotting_data_root = 'data/totaltext'
+totaltext_textspotting_data_root = 'spts-data/totaltext'

 totaltext_textspotting_train = dict(
     type='AdelDataset',
@@ -0,0 +1,15 @@
+totaltext_textspotting_data_root = 'data/totaltext'
+
+totaltext_textspotting_train = dict(
+    type='OCRDataset',
+    data_root=totaltext_textspotting_data_root,
+    ann_file='textspotting_train.json',
+    filter_cfg=dict(filter_empty_gt=True, min_size=32),
+    pipeline=None)
+
+totaltext_textspotting_test = dict(
+    type='OCRDataset',
+    data_root=totaltext_textspotting_data_root,
+    ann_file='textspotting_test.json',
+    test_mode=True,
+    pipeline=None)
@@ -4,7 +4,8 @@ env_cfg = dict(
     mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
     dist_cfg=dict(backend='nccl'),
 )
-randomness = dict(seed=None)
+
+randomness = dict(seed=42)

 default_hooks = dict(
     timer=dict(type='IterTimerHook'),

@@ -1,4 +1,5 @@
-custom_imports = dict(imports=['spts'], allow_failed_imports=False)
+custom_imports = dict(
+    imports=['projects.SPTS.spts'], allow_failed_imports=False)

 file_client_args = dict(backend='disk')

@@ -65,10 +66,7 @@ test_pipeline = [
         type='LoadOCRAnnotationsWithBezier',
         with_bbox=True,
         with_label=True,
         with_bezier=True,
         with_text=True),
     dict(type='Bezier2Polygon'),
     dict(type='ConvertText', dictionary=dictionary),
     dict(
         type='PackTextDetInputs',
         meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))

@@ -87,7 +85,7 @@ train_pipeline = [
         with_text=True),
     dict(type='Bezier2Polygon'),
     dict(type='FixInvalidPolygon'),
-    dict(type='ConvertText', dictionary=dictionary),
+    dict(type='ConvertText', dictionary=dict(**dictionary, num_bins=0)),
     dict(type='RemoveIgnored'),
     dict(type='RandomCrop', min_side_ratio=0.5),
     dict(

@@ -119,7 +117,6 @@ train_pipeline = [
             hue=0.5)
         ],
         prob=0.5),
-    # dict(type='Polygon2Bezier'),
     dict(
         type='PackTextDetInputs',
         meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
@@ -0,0 +1,63 @@
+_base_ = '_base_spts_resnet50.py'
+
+test_pipeline = [
+    dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
+    dict(
+        type='RescaleToShortSide',
+        short_side_lens=[1000],
+        long_side_bound=1824),
+    dict(
+        type='LoadOCRAnnotations',
+        with_bbox=True,
+        with_label=True,
+        with_polygon=True,
+        with_text=True),
+    dict(
+        type='PackTextDetInputs',
+        meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
+]
+
+train_pipeline = [
+    dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
+    dict(
+        type='LoadOCRAnnotations',
+        with_bbox=True,
+        with_label=True,
+        with_polygon=True,
+        with_text=True),
+    dict(type='FixInvalidPolygon'),
+    dict(type='RemoveIgnored'),
+    dict(type='RandomCrop', min_side_ratio=0.5),
+    dict(
+        type='RandomApply',
+        transforms=[
+            dict(
+                type='RandomRotate',
+                max_angle=30,
+                pad_with_fixed_color=True,
+                use_canvas=True)
+        ],
+        prob=0.3),
+    dict(type='FixInvalidPolygon'),
+    dict(
+        type='RandomChoiceResize',
+        scales=[(640, 1600), (672, 1600), (704, 1600), (736, 1600),
+                (768, 1600), (800, 1600), (832, 1600), (864, 1600),
+                (896, 1600)],
+        keep_ratio=True),
+    dict(
+        type='RandomApply',
+        transforms=[
+            dict(
+                type='TorchVisionWrapper',
+                op='ColorJitter',
+                brightness=0.5,
+                contrast=0.5,
+                saturation=0.5,
+                hue=0.5)
+        ],
+        prob=0.5),
+    dict(
+        type='PackTextDetInputs',
+        meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
+]
@@ -1,59 +0,0 @@
-_base_ = [
-    '_base_spts_resnet50.py',
-    '../_base_/datasets/ctw1500-spts.py',
-    '../_base_/default_runtime.py',
-]
-
-num_epochs = 350
-lr = 0.00001
-min_lr = 0.00001
-
-optim_wrapper = dict(
-    type='OptimWrapper',
-    optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001),
-    paramwise_cfg=dict(custom_keys={
-        'backbone': dict(lr_mult=0.1),
-    }))
-
-train_cfg = dict(
-    type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10)
-val_cfg = dict(type='ValLoop')
-test_cfg = dict(type='TestLoop')
-# learning policy
-param_scheduler = [
-    dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True),
-    dict(
-        type='LinearLR',
-        begin=5,
-        end=min(num_epochs,
-                int((lr - min_lr) / (lr / num_epochs)) + 5),
-        end_factor=min_lr / lr,
-        by_epoch=True),
-]
-
-# dataset settings
-ctw1500_textspotting_train = _base_.ctw1500_textspotting_train
-ctw1500_textspotting_train.pipeline = _base_.train_pipeline
-ctw1500_textspotting_test = _base_.ctw1500_textspotting_test
-ctw1500_textspotting_test.pipeline = _base_.test_pipeline
-
-train_dataloader = dict(
-    batch_size=4,
-    num_workers=8,
-    persistent_workers=True,
-    sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2),
-    dataset=ctw1500_textspotting_train)
-
-val_dataloader = dict(
-    batch_size=1,
-    num_workers=4,
-    persistent_workers=True,
-    sampler=dict(type='DefaultSampler', shuffle=False),
-    dataset=ctw1500_textspotting_test)
-
-test_dataloader = val_dataloader
-
-val_cfg = dict(type='ValLoop')
-test_cfg = dict(type='TestLoop')
-
-custom_imports = dict(imports='spts')

@@ -1,57 +0,0 @@
-_base_ = [
-    '_base_spts_resnet50.py',
-    '../_base_/datasets/icdar2013-spts.py',
-    '../_base_/default_runtime.py',
-]
-
-num_epochs = 350
-lr = 0.00001
-min_lr = 0.00001
-
-optim_wrapper = dict(
-    type='OptimWrapper',
-    optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001),
-    paramwise_cfg=dict(custom_keys={
-        'backbone': dict(lr_mult=0.1),
-    }))
-
-train_cfg = dict(
-    type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10)
-val_cfg = dict(type='ValLoop')
-test_cfg = dict(type='TestLoop')
-# learning policy
-param_scheduler = [
-    dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True),
-    dict(
-        type='LinearLR',
-        begin=5,
-        end=min(num_epochs,
-                int((lr - min_lr) / (lr / num_epochs)) + 5),
-        end_factor=min_lr / lr,
-        by_epoch=True),
-]
-
-# dataset settings
-icdar2013_textspotting_train = _base_.icdar2013_textspotting_train
-icdar2013_textspotting_train.pipeline = _base_.train_pipeline
-icdar2013_textspotting_test = _base_.icdar2013_textspotting_test
-icdar2013_textspotting_test.pipeline = _base_.test_pipeline
-
-train_dataloader = dict(
-    batch_size=4,
-    num_workers=8,
-    persistent_workers=True,
-    sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2),
-    dataset=icdar2013_textspotting_train)
-
-val_dataloader = dict(
-    batch_size=1,
-    num_workers=4,
-    persistent_workers=True,
-    sampler=dict(type='DefaultSampler', shuffle=False),
-    dataset=icdar2013_textspotting_test)
-
-test_dataloader = val_dataloader
-
-val_cfg = dict(type='ValLoop')
-test_cfg = dict(type='TestLoop')

@@ -1,57 +0,0 @@
-_base_ = [
-    '_base_spts_resnet50.py',
-    '../_base_/datasets/icdar2015-spts.py',
-    '../_base_/default_runtime.py',
-]
-
-num_epochs = 350
-lr = 0.00001
-min_lr = 0.00001
-
-optim_wrapper = dict(
-    type='OptimWrapper',
-    optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001),
-    paramwise_cfg=dict(custom_keys={
-        'backbone': dict(lr_mult=0.1),
-    }))
-
-train_cfg = dict(
-    type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10)
-val_cfg = dict(type='ValLoop')
-test_cfg = dict(type='TestLoop')
-# learning policy
-param_scheduler = [
-    dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True),
-    dict(
-        type='LinearLR',
-        begin=5,
-        end=min(num_epochs,
-                int((lr - min_lr) / (lr / num_epochs)) + 5),
-        end_factor=min_lr / lr,
-        by_epoch=True),
-]
-
-# dataset settings
-icdar2015_textspotting_train = _base_.icdar2015_textspotting_train
-icdar2015_textspotting_train.pipeline = _base_.train_pipeline
-icdar2015_textspotting_test = _base_.icdar2015_textspotting_test
-icdar2015_textspotting_test.pipeline = _base_.test_pipeline
-
-train_dataloader = dict(
-    batch_size=4,
-    num_workers=8,
-    persistent_workers=True,
-    sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2),
-    dataset=icdar2015_textspotting_train)
-
-val_dataloader = dict(
-    batch_size=1,
-    num_workers=4,
-    persistent_workers=True,
-    sampler=dict(type='DefaultSampler', shuffle=False),
-    dataset=icdar2015_textspotting_test)
-
-test_dataloader = val_dataloader
-
-val_cfg = dict(type='ValLoop')
-test_cfg = dict(type='TestLoop')
@@ -2,7 +2,6 @@ _base_ = [
     '_base_spts_resnet50.py',
     '../_base_/datasets/icdar2013-spts.py',
     '../_base_/datasets/icdar2015-spts.py',
-    '../_base_/datasets/ctw1500-spts.py',
     '../_base_/datasets/totaltext-spts.py',
     '../_base_/datasets/syntext1-spts.py',
     '../_base_/datasets/syntext2-spts.py',

@@ -16,15 +15,12 @@ min_lr = 0.00001

 optim_wrapper = dict(
     type='OptimWrapper',
-    accumulative_counts=2,
     optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001),
     paramwise_cfg=dict(custom_keys={
         'backbone': dict(lr_mult=0.1),
     }))
 train_cfg = dict(
-    type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=20)
-val_cfg = dict(type='ValLoop')
-test_cfg = dict(type='TestLoop')
+    type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=5)
 # learning policy
 param_scheduler = [
     dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True),

@@ -39,26 +35,24 @@ param_scheduler = [

 # dataset settings
 train_list = [
-    _base_.icdar2013_textspotting_train, _base_.icdar2015_textspotting_train,
-    _base_.mlt_textspotting_train, _base_.totaltext_textspotting_train,
-    _base_.syntext1_textspotting_train, _base_.syntext2_textspotting_train,
-    _base_.ctw1500_textspotting_train
+    _base_.icdar2013_textspotting_train,
+    _base_.icdar2015_textspotting_train,
+    _base_.mlt_textspotting_train,
+    _base_.totaltext_textspotting_train,
+    _base_.syntext1_textspotting_train,
+    _base_.syntext2_textspotting_train,
 ]

 train_dataset = dict(
     type='ConcatDataset', datasets=train_list, pipeline=_base_.train_pipeline)

-train_dataloader = dict(
-    batch_size=4,
-    num_workers=8,
-    persistent_workers=True,
-    sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2),
-    dataset=train_dataset)
-
 val_dataloader = None
 test_dataloader = None
 val_evaluator = None
 test_evaluator = None

 val_cfg = None
 test_cfg = None
+
+train_dataloader = dict(
+    batch_size=8,
+    num_workers=8,
+    pin_memory=True,
+    persistent_workers=True,
+    sampler=dict(type='BatchAugSampler', shuffle=True, num_repeats=2),
+    dataset=train_dataset)
@@ -0,0 +1,87 @@
+_base_ = [
+    '_base_spts_resnet50_mmocr.py',
+    '../_base_/datasets/icdar2013.py',
+    '../_base_/default_runtime.py',
+]
+
+load_from = 'work_dirs/spts_resnet50_150e_pretrain-spts/epoch_150.pth'
+
+num_epochs = 200
+lr = 0.00001
+
+default_hooks = dict(
+    checkpoint=dict(
+        type='CheckpointHook',
+        save_best='generic/hmean',
+        rule='greater',
+        _delete_=True),
+    logger=dict(type='LoggerHook', interval=1))
+
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001),
+    paramwise_cfg=dict(custom_keys={
+        'backbone': dict(lr_mult=0.1),
+    }))
+
+train_cfg = dict(
+    type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+# dataset settings
+icdar2013_textspotting_train = _base_.icdar2013_textspotting_train
+icdar2013_textspotting_train.pipeline = _base_.train_pipeline
+icdar2013_textspotting_test = _base_.icdar2013_textspotting_test
+icdar2013_textspotting_test.pipeline = _base_.test_pipeline
+
+train_dataloader = dict(
+    batch_size=8,
+    num_workers=8,
+    pin_memory=True,
+    persistent_workers=True,
+    sampler=dict(type='BatchAugSampler', shuffle=True, num_repeats=2),
+    dataset=icdar2013_textspotting_train)
+
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    pin_memory=True,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=icdar2013_textspotting_test)
+
+test_dataloader = val_dataloader
+
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+val_evaluator = [
+    dict(
+        type='E2EPointMetric',
+        prefix='generic',
+        lexicon_path='data/icdar2013/lexicons/GenericVocabulary_new.txt',
+        pair_path='data/icdar2013/lexicons/'
+        'GenericVocabulary_pair_list.txt',
+        match_dist_thr=None),
+    dict(
+        type='E2EPointMetric',
+        prefix='weak',
+        lexicon_path='data/icdar2013/lexicons/'
+        'ch2_test_vocabulary_new.txt',
+        pair_path='data/icdar2013/lexicons/'
+        'ch2_test_vocabulary_pair_list.txt',
+        match_dist_thr=0.4),
+    dict(
+        type='E2EPointMetric',
+        prefix='strong',
+        lexicon_path='data/icdar2013/lexicons/'
+        'new_strong_lexicon/lexicons/',
+        lexicon_mapping=('(.*).jpg', r'new_voc_\1.txt'),
+        pair_path='data/icdar2013/lexicons/'
+        'new_strong_lexicon/pairs/',
+        pair_mapping=('(.*).jpg', r'pair_voc_\1.txt'),
+        match_dist_thr=0.4),
+]
+
+test_evaluator = val_evaluator
@@ -0,0 +1,87 @@
+_base_ = [
+    '_base_spts_resnet50_mmocr.py',
+    '../_base_/datasets/icdar2015.py',
+    '../_base_/default_runtime.py',
+]
+
+load_from = 'work_dirs/spts_resnet50_150e_pretrain-spts/epoch_150.pth'
+
+num_epochs = 200
+lr = 0.00001
+
+default_hooks = dict(
+    checkpoint=dict(
+        type='CheckpointHook',
+        save_best='generic/hmean',
+        rule='greater',
+        _delete_=True),
+    logger=dict(type='LoggerHook', interval=10))
+
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001),
+    paramwise_cfg=dict(custom_keys={
+        'backbone': dict(lr_mult=0.1),
+    }))
+
+train_cfg = dict(
+    type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+# dataset settings
+icdar2015_textspotting_train = _base_.icdar2015_textspotting_train
+icdar2015_textspotting_train.pipeline = _base_.train_pipeline
+icdar2015_textspotting_test = _base_.icdar2015_textspotting_test
+icdar2015_textspotting_test.pipeline = _base_.test_pipeline
+
+train_dataloader = dict(
+    batch_size=8,
+    num_workers=8,
+    pin_memory=True,
+    persistent_workers=True,
+    sampler=dict(type='BatchAugSampler', shuffle=True, num_repeats=2),
+    dataset=icdar2015_textspotting_train)
+
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    pin_memory=True,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=icdar2015_textspotting_test)
+
+test_dataloader = val_dataloader
+
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+val_evaluator = [
+    dict(
+        type='E2EPointMetric',
+        prefix='generic',
+        lexicon_path='data/icdar2015/lexicons/GenericVocabulary_new.txt',
+        pair_path='data/icdar2015/lexicons/'
+        'GenericVocabulary_pair_list.txt',
+        match_dist_thr=None),
+    dict(
+        type='E2EPointMetric',
+        prefix='weak',
+        lexicon_path='data/icdar2015/lexicons/'
+        'ch4_test_vocabulary_new.txt',
+        pair_path='data/icdar2015/lexicons/'
+        'ch4_test_vocabulary_pair_list.txt',
+        match_dist_thr=0.4),
+    dict(
+        type='E2EPointMetric',
+        prefix='strong',
+        lexicon_path='data/icdar2015/lexicons/'
+        'new_strong_lexicon/lexicons/',
+        lexicon_mapping=('(.*).jpg', r'new_voc_\1.txt'),
+        pair_path='data/icdar2015/lexicons/'
+        'new_strong_lexicon/pairs/',
+        pair_mapping=('(.*).jpg', r'pair_voc_\1.txt'),
+        match_dist_thr=0.4),
+]
+
+test_evaluator = val_evaluator
@@ -1,12 +1,21 @@
 _base_ = [
-    '_base_spts_resnet50.py',
-    '../_base_/datasets/totaltext-spts.py',
+    '_base_spts_resnet50_mmocr.py',
+    '../_base_/datasets/totaltext.py',
     '../_base_/default_runtime.py',
 ]

-num_epochs = 350
+load_from = 'work_dirs/spts_resnet50_150e_pretrain-spts/epoch_150.pth'
+
+num_epochs = 200
 lr = 0.00001
-min_lr = 0.00001

+default_hooks = dict(
+    checkpoint=dict(
+        type='CheckpointHook',
+        save_best='none/hmean',
+        rule='greater',
+        _delete_=True),
+    logger=dict(type='LoggerHook', interval=10))
+
 optim_wrapper = dict(
     type='OptimWrapper',

@@ -19,17 +28,6 @@ train_cfg = dict(
     type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
-# learning policy
-param_scheduler = [
-    dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True),
-    dict(
-        type='LinearLR',
-        begin=5,
-        end=min(num_epochs,
-                int((lr - min_lr) / (lr / num_epochs)) + 5),
-        end_factor=min_lr / lr,
-        by_epoch=True),
-]

 # dataset settings
 totaltext_textspotting_train = _base_.totaltext_textspotting_train

@@ -38,16 +36,18 @@ totaltext_textspotting_test = _base_.totaltext_textspotting_test
 totaltext_textspotting_test.pipeline = _base_.test_pipeline

 train_dataloader = dict(
-    batch_size=4,
+    batch_size=8,
     num_workers=8,
     persistent_workers=True,
-    sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2),
+    pin_memory=True,
+    sampler=dict(type='BatchAugSampler', shuffle=True, num_repeats=2),
     dataset=totaltext_textspotting_train)

 val_dataloader = dict(
     batch_size=1,
     num_workers=4,
     persistent_workers=True,
+    pin_memory=True,
     sampler=dict(type='DefaultSampler', shuffle=False),
     dataset=totaltext_textspotting_test)

@@ -55,3 +55,21 @@ test_dataloader = val_dataloader

 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
+
+val_evaluator = [
+    dict(
+        type='E2EPointMetric',
+        prefix='none',
+        word_spotting=True,
+        match_dist_thr=0.4),
+    dict(
+        type='E2EPointMetric',
+        prefix='full',
+        lexicon_path='data/totaltext/lexicons/weak_voc_new.txt',
+        pair_path='data/totaltext/lexicons/'
+        'weak_voc_pair_list.txt',
+        word_spotting=True,
+        match_dist_thr=0.4),
+]
+
+test_evaluator = val_evaluator
@@ -40,6 +40,7 @@ class AdelDataset(CocoDataset):
             None img. The maximum extra number of cycles to get a valid
             image. Defaults to 1000.
     """
+    METAINFO = {'classes': ('text', )}

     def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
         """Parse raw annotation to target format.
@@ -1,14 +1,20 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Sequence
+import glob
+import os.path as osp
+import re
+from typing import Dict, List, Optional, Sequence, Tuple

 import numpy as np
 import torch
 from mmengine.evaluator import BaseMetric
 from mmengine.logging import MMLogger
-from shapely.geometry import LineString, Point
+from rapidfuzz.distance import Levenshtein
+from shapely.geometry import Point

 from mmocr.registry import METRICS

+# TODO: CTW1500 read pair
+

 @METRICS.register_module()
 class E2EPointMetric(BaseMetric):
@@ -17,7 +23,20 @@ class E2EPointMetric(BaseMetric):
     Args:
         text_score_thrs (dict): Best text score threshold searching
             space. Defaults to dict(start=0.8, stop=1, step=0.01).
-        TODO: docstr
+        word_spotting (bool): Whether to work in word spotting mode. Defaults
+            to False.
+        lexicon_path (str, optional): Lexicon path for word spotting, which
+            points to a lexicon file or a directory. Defaults to None.
+        lexicon_mapping (tuple, optional): The rule to map test image name to
+            its corresponding lexicon file. Only effective when lexicon path
+            is a directory. Defaults to ('(.*).jpg', r'\1.txt').
+        pair_path (str, optional): Pair path for word spotting, which points
+            to a pair file or a directory. Defaults to None.
+        pair_mapping (tuple, optional): The rule to map test image name to
+            its corresponding pair file. Only effective when pair path is a
+            directory. Defaults to ('(.*).jpg', r'\1.txt').
+        match_dist_thr (float, optional): Matching distance threshold for
+            word spotting. Defaults to None.
         collect_device (str): Device name used for collecting results from
             different ranks during distributed training. Must be 'cpu' or
             'gpu'. Defaults to 'cpu'.
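Concretely, `lexicon_mapping` and `pair_mapping` are plain `re.sub` rules applied to the test image's file name (see `_map_img_name` further down). With the strong-lexicon settings used in the ICDAR2013 config above, the mapping behaves like this (illustrative snippet):

```python
import re

lexicon_mapping = ('(.*).jpg', r'new_voc_\1.txt')
pair_mapping = ('(.*).jpg', r'pair_voc_\1.txt')

img_name = 'img_101.jpg'  # hypothetical test image name
print(re.sub(lexicon_mapping[0], lexicon_mapping[1], img_name))  # new_voc_img_101.txt
print(re.sub(pair_mapping[0], pair_mapping[1], img_name))        # pair_voc_img_101.txt
```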
@@ -31,20 +50,52 @@ class E2EPointMetric(BaseMetric):
     def __init__(self,
                  text_score_thrs: Dict = dict(start=0.8, stop=1, step=0.01),
+                 word_spotting: bool = False,
+                 lexicon_path: Optional[str] = None,
+                 lexicon_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'),
+                 pair_path: Optional[str] = None,
+                 pair_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'),
+                 match_dist_thr: Optional[float] = None,
                  collect_device: str = 'cpu',
                  prefix: Optional[str] = None) -> None:
         super().__init__(collect_device=collect_device, prefix=prefix)
         self.text_score_thrs = np.arange(**text_score_thrs)
+        self.word_spotting = word_spotting
+        self.match_dist_thr = match_dist_thr
+        if lexicon_path:
+            self.lexicon_mapping = lexicon_mapping
+            self.pair_mapping = pair_mapping
+            self.lexicons = self._read_lexicon(lexicon_path)
+            self.pairs = self._read_pair(pair_path)
+
+    def _read_lexicon(self, lexicon_path: str) -> List[str]:
+        if lexicon_path.endswith('.txt'):
+            lexicon = open(lexicon_path, 'r').read().splitlines()
+            lexicon = [ele.strip() for ele in lexicon]
+        else:
+            lexicon = {}
+            for file in glob.glob(osp.join(lexicon_path, '*.txt')):
+                basename = osp.basename(file)
+                lexicon[basename] = self._read_lexicon(file)
+        return lexicon
+
+    def _read_pair(self, pair_path: str) -> Dict[str, str]:
+        pairs = {}
+        if pair_path.endswith('.txt'):
+            pair_lines = open(pair_path, 'r').read().splitlines()
+            for line in pair_lines:
+                line = line.strip()
+                word = line.split(' ')[0].upper()
+                word_gt = line[len(word) + 1:]
+                pairs[word] = word_gt
+        else:
+            for file in glob.glob(osp.join(pair_path, '*.txt')):
+                basename = osp.basename(file)
+                pairs[basename] = self._read_pair(file)
+        return pairs

     def poly_center(self, poly_pts):
         poly_pts = np.array(poly_pts).reshape(-1, 2)
-        num_points = poly_pts.shape[0]
-        line1 = LineString(poly_pts[int(num_points / 2):])
-        line2 = LineString(poly_pts[:int(num_points / 2)])
-        mid_pt1 = np.array(line1.interpolate(0.5, normalized=True).coords[0])
-        mid_pt2 = np.array(line2.interpolate(0.5, normalized=True).coords[0])
-        return (mid_pt1 + mid_pt2) / 2
+        return poly_pts.mean(0)

     def process(self, data_batch: Sequence[Dict],
                 data_samples: Sequence[Dict]) -> None:
@@ -87,6 +138,8 @@ class E2EPointMetric(BaseMetric):
                 ~pred_ignore_flags)

             result = dict(
+                # reserved for image-level lexcions
+                gt_img_name=osp.basename(data_sample.get('img_path', '')),
                 text_scores=text_scores,
                 pred_points=pred_points,
                 gt_points=gt_points,

@@ -125,10 +178,33 @@ class E2EPointMetric(BaseMetric):
             gt_texts = result['gt_texts']
             pred_texts = result['pred_texts']
             gt_ignore_flags = result['gt_ignore_flags']
+            gt_img_name = result['gt_img_name']

+            # Correct the words with lexicon
+            pred_dist_flags = np.zeros(len(pred_texts), dtype=bool)
+            if hasattr(self, 'lexicons'):
+                for i, pred_text in enumerate(pred_texts):
+                    # If it's an image-level lexicon
+                    if isinstance(self.lexicons, dict):
+                        lexicon_name = self._map_img_name(
+                            gt_img_name, self.lexicon_mapping)
+                        pair_name = self._map_img_name(gt_img_name,
+                                                       self.pair_mapping)
+                        pred_texts[i], match_dist = self._match_word(
+                            pred_text, self.lexicons[lexicon_name],
+                            self.pairs[pair_name])
+                    else:
+                        pred_texts[i], match_dist = self._match_word(
+                            pred_text, self.lexicons, self.pairs)
+                    if (self.match_dist_thr
+                            and match_dist >= self.match_dist_thr):
+                        # won't even count this as a prediction
+                        pred_dist_flags[i] = True
+
             # Filter out predictions by IoU threshold
             for i, text_score_thr in enumerate(self.text_score_thrs):
-                pred_ignore_flags = text_scores < text_score_thr
+                pred_ignore_flags = pred_dist_flags | (
+                    text_scores < text_score_thr)
                 filtered_pred_texts = self._get_true_elements(
                     pred_texts, ~pred_ignore_flags)
                 filtered_pred_points = self._get_true_elements(

@@ -148,10 +224,8 @@ class E2EPointMetric(BaseMetric):
                     min_idx = np.argmin(dists)
                     if gt_texts[min_idx] == '###' or gt_ignore_flags[min_idx]:
                         continue
-                    # if not gt_matched[min_idx] and self.text_match(
-                    #         gt_texts[min_idx].upper(), pred_text.upper()):
-                    if (not gt_matched[min_idx] and gt_texts[min_idx].upper()
-                            == pred_text.upper()):
+                    if not gt_matched[min_idx] and (
+                            pred_text.upper() == gt_texts[min_idx].upper()):
                         gt_matched[min_idx] = True
                         num_tp[i] += 1
                 num_preds[i] += 1
@@ -173,6 +247,69 @@ class E2EPointMetric(BaseMetric):
                 best_eval_results = eval_results
         return best_eval_results

+    def _map_img_name(self, img_name: str, mapping: Tuple[str, str]) -> str:
+        """Map the image name to the another one based on mapping."""
+        return re.sub(mapping[0], mapping[1], img_name)
+
     def _true_indexes(self, array: np.ndarray) -> np.ndarray:
         """Get indexes of True elements from a 1D boolean array."""
         return np.where(array)[0]
+
+    def _word_spotting_filter(self, gt_ignore_flags: np.ndarray,
+                              gt_texts: List[str]
+                              ) -> Tuple[np.ndarray, List[str]]:
+        """Filter out gt instances that cannot be in a valid dictionary, and do
+        some simple preprocessing to texts."""
+
+        for i in range(len(gt_texts)):
+            if gt_ignore_flags[i]:
+                continue
+            text = gt_texts[i]
+            if text[-2:] in ["'s", "'S"]:
+                text = text[:-2]
+            text = text.strip('-')
+            for char in "'!?.:,*\"()·[]/":
+                text = text.replace(char, ' ')
+            text = text.strip()
+            gt_ignore_flags[i] = not self._include_in_dict(text)
+            if not gt_ignore_flags[i]:
+                gt_texts[i] = text
+
+        return gt_ignore_flags, gt_texts
+
+    def _include_in_dict(self, text: str) -> bool:
+        """Check if the text could be in a valid dictionary."""
+        if len(text) != len(text.replace(' ', '')) or len(text) < 3:
+            return False
+        not_allowed = '×÷·'
+        valid_ranges = [(ord(u'a'), ord(u'z')), (ord(u'A'), ord(u'Z')),
+                        (ord(u'À'), ord(u'ƿ')), (ord(u'DŽ'), ord(u'ɿ')),
+                        (ord(u'Ά'), ord(u'Ͽ')), (ord(u'-'), ord(u'-'))]
+        for char in text:
+            code = ord(char)
+            if (not_allowed.find(char) != -1):
+                return False
+            valid = any(code >= r[0] and code <= r[1] for r in valid_ranges)
+            if not valid:
+                return False
+        return True
+
+    def _match_word(self,
+                    text: str,
+                    lexicons: List[str],
+                    pairs: Optional[Dict[str, str]] = None) -> Tuple[str, int]:
+        """Match the text with the lexicons and pairs."""
+        text = text.upper()
+        matched_word = ''
+        matched_dist = 100
+        for lexicon in lexicons:
+            lexicon = lexicon.upper()
-            norm_dist = Levenshtein.distance(text, lexicon)
+            norm_dist = Levenshtein.normalized_distance(text, lexicon)
+            if norm_dist < matched_dist:
+                matched_dist = norm_dist
+                if pairs:
+                    matched_word = pairs[lexicon]
+                else:
+                    matched_word = lexicon
+        return matched_word, matched_dist
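Note the switch from the raw edit distance to rapidfuzz's normalized variant, which rescales the Levenshtein distance into [0, 1] (for the default weights, division by the longer string's length), so a threshold such as `match_dist_thr=0.4` is independent of word length. For example:

```python
from rapidfuzz.distance import Levenshtein

print(Levenshtein.distance('MMOCR', 'MMOCRR'))             # 1
print(Levenshtein.normalized_distance('MMOCR', 'MMOCRR'))  # 1 / 6 ≈ 0.167
print(Levenshtein.normalized_distance('AB', 'XY'))         # 1.0
```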
@@ -58,18 +58,15 @@ class SPTSDecoder(BaseDecoder):
             max_seq_len=self.max_seq_len,
             init_cfg=init_cfg)
         self.num_bins = num_bins
-        self.shifted_seq_end_idx = self.num_bins + self.dictionary.seq_end_idx
-        self.shifted_start_idx = self.num_bins + self.dictionary.start_idx
-
-        actual_num_classes = self.dictionary.num_classes + num_bins

-        self.embedding = DecoderEmbeddings(
-            actual_num_classes, self.dictionary.padding_idx + num_bins,
-            d_model, self.max_seq_len, dropout)
+        self.embedding = DecoderEmbeddings(self.dictionary.num_classes,
+                                           self.dictionary.padding_idx,
+                                           d_model, self.max_seq_len, dropout)
         self.pos_embedding = PositionEmbeddingSine(d_model // 2)

         self.vocab_embed = self._gen_vocab_embed(d_model, d_model,
-                                                 actual_num_classes, 3)
+                                                 self.dictionary.num_classes,
+                                                 3)
         encoder_layer = TransformerEncoderLayer(d_model, n_head, d_feedforward,
                                                 dropout, 'relu',
                                                 normalize_before)

@@ -166,7 +163,7 @@ class SPTSDecoder(BaseDecoder):
         max_probs = []
         seq = torch.zeros(
             batch_size, 1, dtype=torch.long).to(
-                out_enc.device) + self.shifted_start_idx
+                out_enc.device) + self.dictionary.start_idx
         for i in range(self.max_seq_len):
             tgt = self.embedding(seq).permute(1, 0, 2)
             hs = self.decoder(

@@ -182,13 +179,13 @@ class SPTSDecoder(BaseDecoder):

             # bins chars unk eos seq_eos sos padding
             if i % 27 == 0:  # coordinate or eos
-                out[:, self.num_bins:self.shifted_seq_end_idx] = 0
-                out[:, self.shifted_seq_end_idx + 1:] = 0
+                out[:, self.num_bins:self.dictionary.seq_end_idx] = 0
+                out[:, self.dictionary.seq_end_idx + 1:] = 0
             elif i % 27 == 1:  # coordinate
                 out[:, self.num_bins:] = 0
             else:  # chars
                 out[:, :self.num_bins] = 0
-                out[:, self.shifted_seq_end_idx:] = 0
+                out[:, self.dictionary.seq_end_idx:] = 0

             max_prob, extra_seq = torch.max(out, dim=-1, keepdim=True)
             # prob, extra_seq = out.topk(dim=-1, k=1)
@@ -196,7 +193,7 @@ class SPTSDecoder(BaseDecoder):
             # TODO: optimize for multi-batch
             seq = torch.cat([seq, extra_seq], dim=-1)
             max_probs.append(max_prob)
-            if extra_seq[0] == self.shifted_seq_end_idx:
+            if extra_seq[0] == self.dictionary.seq_end_idx:
                 break

         max_probs = torch.cat(max_probs, dim=-1)
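The masks in the decoding loop above cycle with period 27 because, in this serialization, each text instance occupies 2 point-coordinate tokens followed by its characters; assuming the commonly used `max_text_len` of 25, that is 2 + 25 = 27 tokens per instance (the same `2 + max_text_len` constant appears in the module loss below). Illustrative only:

```python
max_text_len = 25                       # assumed value
tokens_per_instance = 2 + max_text_len  # 27: two coordinates + characters
for i in (0, 1, 2, 26, 27, 28):
    slot = i % tokens_per_instance
    kind = ('coordinate/eos' if slot == 0 else
            'coordinate' if slot == 1 else 'character')
    print(i, kind)
```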
@@ -15,6 +15,8 @@ class SPTSDictionary(Dictionary):
     Args:
         dict_file (str): The path of Character dict file which a single
             character must occupies a line.
+        num_bins (int): Number of bins dividing the image, which is used to
+            shift the character indexes. Defaults to 1000.
         with_start (bool): The flag to control whether to include the start
             token. Defaults to False.
         with_end (bool): The flag to control whether to include the end token.

@@ -45,6 +47,7 @@ class SPTSDictionary(Dictionary):
     def __init__(
         self,
         dict_file: str,
+        num_bins: int = 1000,
         with_start: bool = False,
         with_end: bool = False,
         with_seq_end: bool = False,
@@ -74,6 +77,26 @@ class SPTSDictionary(Dictionary):
             padding_token=padding_token,
             unknown_token=unknown_token)

+        self.num_bins = num_bins
+        self._shift_idx()
+
+    @property
+    def num_classes(self) -> int:
+        """int: Number of output classes. Special tokens are counted.
+        """
+        return len(self._dict) + self.num_bins
+
+    def _shift_idx(self):
+        idx_terms = [
+            'start_idx', 'end_idx', 'unknown_idx', 'seq_end_idx', 'padding_idx'
+        ]
+        for term in idx_terms:
+            value = getattr(self, term)
+            if value:
+                setattr(self, term, value + self.num_bins)
+        for char in self._dict:
+            self._char2idx[char] += self.num_bins
+
     def _update_dict(self):
         """Update the dict with tokens according to parameters."""
         # BOS/EOS
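To make the shift concrete: the first `num_bins` output classes encode quantized point coordinates, and every character or special-token index is moved up by `num_bins`, so `num_classes` becomes `len(self._dict) + num_bins`. A small illustrative layout (the raw dictionary size of 97 is just an assumed example):

```python
num_bins = 1000      # indexes 0..999 encode quantized x / y coordinates
raw_dict_size = 97   # assumed size of the character dict file

char_idx_before = 5                            # a character's index in the raw dict
char_idx_after = char_idx_before + num_bins    # 1005 after _shift_idx()
num_classes = num_bins + raw_dict_size         # 1097 == len(self._dict) + num_bins
print(char_idx_after, num_classes)
```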
@@ -129,10 +152,11 @@ class SPTSDictionary(Dictionary):
         assert isinstance(index, (list, tuple))
         string = ''
         for i in index:
-            assert i < len(self._dict), f'Index: {i} out of range! Index ' \
-                f'must be less than {len(self._dict)}'
+            assert i < self.num_classes, f'Index: {i} out of range! Index ' \
+                f'must be less than {self.num_classes}'
             # TODO: find its difference from ignore_chars
             # in TextRecogPostprocessor
-            if self._dict[i] is not None:
-                string += self._dict[i]
+            shifted_i = i - self.num_bins
+            if self._dict[shifted_i] is not None:
+                string += self._dict[shifted_i]
         return string
@@ -81,7 +81,7 @@ class SPTSModuleLoss(CEModuleLoss):
         self.max_num_text = (self.max_seq_len - 1) // (2 + max_text_len)
         self.num_bins = num_bins

-        weights = torch.ones(self.dictionary.num_classes + num_bins)
+        weights = torch.ones(self.dictionary.num_classes, dtype=torch.float32)
         weights[self.dictionary.seq_end_idx] = seq_eos_coef
         weights.requires_grad_ = False
         self.loss_ce = nn.CrossEntropyLoss(

@@ -117,18 +117,24 @@ class SPTSModuleLoss(CEModuleLoss):
             if data_sample.get('have_target', False):
                 continue

-            if len(data_sample.gt_instances.polygons) > self.max_num_text:
+            if len(data_sample.gt_instances) > self.max_num_text:
                 keep = random.sample(
-                    range(len(data_sample.gt_instances['polygons'])),
-                    self.max_num_text)
+                    range(len(data_sample.gt_instances)), self.max_num_text)
                 data_sample.gt_instances = data_sample.gt_instances[keep]

             gt_instances = data_sample.gt_instances

-            if len(gt_instances.polygons) > 0:
+            if len(gt_instances) > 0:
                 center_pts = []
                 # Slightly different from the original implementation
                 # which gets the center points from bezier curves
                 # for bezier_pt in gt_instances.beziers:
                 #     bezier_pt = bezier_pt.reshape(8, 2)
                 #     mid_pt1 = sample_bezier_curve(
                 #         bezier_pt[:4], mid_point=True)
                 #     mid_pt2 = sample_bezier_curve(
                 #         bezier_pt[4:], mid_point=True)
                 #     center_pt = (mid_pt1 + mid_pt2) / 2
                 for polygon in gt_instances.polygons:
                     center_pt = polygon.reshape(-1, 2).mean(0)
                     center_pts.append(center_pt)

@@ -152,7 +158,7 @@ class SPTSModuleLoss(CEModuleLoss):
                     dtype=torch.long) + self.dictionary.end_idx
                 max_len = min(self.max_text_len - 1, len(indexes))
                 indexes_tensor[:max_len] = torch.LongTensor(indexes)[:max_len]
-                indexes_tensor = indexes_tensor + self.num_bins
+                indexes_tensor = indexes_tensor
                 gt_indexes.append(indexes_tensor)

             if len(gt_indexes) == 0:

@@ -164,15 +170,12 @@ class SPTSModuleLoss(CEModuleLoss):

             if self.dictionary.start_idx is not None:
                 gt_indexes = torch.cat([
-                    torch.LongTensor(
-                        [self.dictionary.start_idx + self.num_bins]),
-                    gt_indexes
+                    torch.LongTensor([self.dictionary.start_idx]), gt_indexes
                 ])
             if self.dictionary.seq_end_idx is not None:
                 gt_indexes = torch.cat([
                     gt_indexes,
-                    torch.LongTensor(
-                        [self.dictionary.seq_end_idx + self.num_bins])
+                    torch.LongTensor([self.dictionary.seq_end_idx])
                 ])

             batch_max_len = max(batch_max_len, len(gt_indexes))

@@ -190,7 +193,7 @@ class SPTSModuleLoss(CEModuleLoss):

             padded_indexes = (
                 torch.zeros(batch_max_len, dtype=torch.long) +
-                self.dictionary.padding_idx + self.num_bins)
+                self.dictionary.padding_idx)
             padded_indexes[:len(indexes)] = indexes
             data_sample.gt_instances.set_metainfo(
                 dict(padded_indexes=padded_indexes))
@@ -102,12 +102,11 @@ class SPTSPostprocessor(BaseTextRecogPostprocessor):
             for char_index, char_score in zip(output_index[2:],
                                               output_score[2:]):
                 # the first num_bins indexes are for points
-                dict_idx = char_index - self.num_bins
-                if dict_idx in self.ignore_indexes:
+                if char_index in self.ignore_indexes:
                     continue
-                if dict_idx == self.dictionary.end_idx:
+                if char_index == self.dictionary.end_idx:
                     break
-                indexes[-1].append(dict_idx)
+                indexes[-1].append(char_index)
                 char_scores.append(char_score)
             text_scores.append(np.mean(char_scores).item())
         return indexes, text_scores, points, pt_scores