[SPTS] train (#1756)

* [Feature] Add RepeatAugSampler

* initial commit

* spts inference done

* merge repeat_aug (bug in multi-node?)

* fix inference

* train done

* rm readme

* Revert "merge repeat_aug (bug in multi-node?)"

This reverts commit 393506a97c.

* Revert "[Feature] Add RepeatAugSampler"

This reverts commit 2089b02b48.

* remove utils

* readme & conversion script

* update readme

* fix

* optimize

* rename cfg & del compose

* fix

* fix

* tmp commit

* update training setting

* update cfg

* update readme

* e2e metric

* update cfg

* fix

* update readme

* fix

* update
Tong Gao 2023-03-07 14:18:01 +08:00 committed by GitHub
parent 81fd74c266
commit 5670695338
26 changed files with 561 additions and 293 deletions

View File

@ -152,8 +152,8 @@ def crop_polygon(polygon: ArrayLike,
np.array or None: Cropped polygon. If the polygon is not within the
crop box, return None.
"""
poly = poly2shapely(polygon)
crop_poly = poly2shapely(bbox2poly(crop_box))
poly = poly_make_valid(poly2shapely(polygon))
crop_poly = poly_make_valid(poly2shapely(bbox2poly(crop_box)))
poly_cropped = poly.intersection(crop_poly)
if poly_cropped.area == 0. or not isinstance(
poly_cropped, shapely.geometry.polygon.Polygon):
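
For context on the `poly_make_valid` change above: shapely can raise a `TopologyException` when intersecting an invalid (e.g. self-intersecting) polygon, so both operands are repaired before taking the intersection. A minimal illustration with plain shapely (not MMOCR's `poly_make_valid` itself; the bow-tie coordinates are a made-up example):

```python
# A self-intersecting 'bow-tie' polygon is invalid and unsafe to intersect.
from shapely.geometry import Polygon
from shapely.validation import make_valid  # requires shapely >= 1.8

bowtie = Polygon([(0, 0), (2, 2), (2, 0), (0, 2)])
print(bowtie.is_valid)                # False
fixed = make_valid(bowtie)            # repaired into a valid (Multi)Polygon
crop = Polygon([(0, 0), (2, 0), (2, 2), (0, 2)])
print(fixed.intersection(crop).area)  # intersection is now well-defined
```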

View File

@ -36,10 +36,24 @@ $env:PYTHONPATH=Get-Location
### Dataset
**As of now, the implementation uses datasets provided by SPTS, but these datasets
will be available in MMOCR's dataset preparer very soon.**
As of now, the implementation uses datasets provided by SPTS for pre-training, and MMOCR's datasets for fine-tuning and testing. This is because the test split of SPTS's datasets does not contain enough information for e2e evaluation, and MMOCR's dataset preparer does not yet support all the datasets used in SPTS. *We are working on this issue, and these datasets will be available in MMOCR's dataset preparer very soon.*
Download and extract all the datasets into `data/` following [SPTS official guide](https://github.com/shannanyinxiang/SPTS#dataset).
Please follow these steps to prepare the datasets:
- Download and extract all the SPTS datasets into `spts-data/` following [SPTS official guide](https://github.com/shannanyinxiang/SPTS#dataset).
- Use [Dataset Preparer](https://mmocr.readthedocs.io/en/dev-1.x/user_guides/data_prepare/dataset_preparer.html) to prepare `icdar2013`, `icdar2015` and `totaltext` for the `textspotting` task.
```shell
# Run in MMOCR's root directory
python tools/dataset_converters/prepare_dataset.py icdar2013 icdar2015 totaltext --task textspotting
```
Then create a soft link to the `data/` directory in the project root directory:
```shell
ln -s ../../data/ .
```
### Training commands
@ -48,13 +62,13 @@ In the current directory, run the following command to train the model:
#### Pretrain
```bash
mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/
mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/ --amp
```
To train on multiple GPUs, e.g. 8 GPUs, run the following command:
```bash
mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/ --launcher pytorch --gpus 8
mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/ --launcher pytorch --gpus 8 --amp
```
#### Finetune
@ -62,13 +76,13 @@ mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_
Similarly, run the following command to finetune the model on a dataset (e.g. icdar2013):
```bash
mim train mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work_dirs/ --cfg-options "load_from={CHECKPOINT_PATH}"
mim train mmocr config/spts/spts_resnet50_8xb8-200e_icdar2013.py --work-dir work_dirs/ --cfg-options "load_from={CHECKPOINT_PATH}" --amp
```
To finetune on multiple GPUs, e.g. 8 GPUs, run the following command:
```bash
mim train mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work_dirs/ --launcher pytorch --gpus 8 --cfg-options "load_from={CHECKPOINT_PATH}"
mim train mmocr config/spts/spts_resnet50_8xb8-200e_icdar2013.py --work-dir work_dirs/ --launcher pytorch --gpus 8 --cfg-options "load_from={CHECKPOINT_PATH}" --amp
```
### Testing commands
@ -76,24 +90,29 @@ mim train mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work
In the current directory, run the following command to test the model on a dataset (e.g. icdar2013):
```bash
mim test mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work_dirs/ --checkpoint ${CHECKPOINT_PATH}
mim test mmocr config/spts/spts_resnet50_8xb8-200e_icdar2013.py --work-dir work_dirs/ --checkpoint ${CHECKPOINT_PATH}
```
## Results
## Convert Weights from Official Repo
The weights from MMOCR are on the way. Users may download the weights from [SPTS](https://github.com/shannanyinxiang/SPTS#inference) and use the conversion script to convert them into MMOCR format.
Users may download the weights from [SPTS](https://github.com/shannanyinxiang/SPTS#inference) and use the conversion script to convert them into MMOCR format.
```bash
python tools/ckpt_adapter.py [SPTS_WEIGHTS_PATH] [MMOCR_WEIGHTS_PATH]
```
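Such an adapter usually amounts to loading the official checkpoint and remapping parameter names to MMOCR's module layout. Below is a minimal sketch of that idea, not the actual `tools/ckpt_adapter.py`; the `KEY_MAP` prefix rule is a made-up example.

```python
# Minimal checkpoint-adapter sketch. Assumes the official weights are a plain
# torch checkpoint and that only parameter names differ between the repos.
import sys

import torch

KEY_MAP = {
    'backbone.': 'encoder.backbone.',  # hypothetical prefix rewrite
}


def adapt(spts_path: str, mmocr_path: str) -> None:
    ckpt = torch.load(spts_path, map_location='cpu')
    state_dict = ckpt.get('model', ckpt)  # unwrap if nested
    new_state_dict = {}
    for key, value in state_dict.items():
        for src, dst in KEY_MAP.items():
            if key.startswith(src):
                key = dst + key[len(src):]
                break
        new_state_dict[key] = value
    torch.save({'state_dict': new_state_dict}, mmocr_path)


if __name__ == '__main__':
    adapt(sys.argv[1], sys.argv[2])
```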
Here are the results obtained with the converted weights. The results are lower than the original ones due to differences in the test splits of the datasets, which will be addressed in the next update.
## Results
| Name | Model | E2E-None-Hmean |
| :--------: | :-------------------: | :------------: |
| ICDAR 2013 | ic13.pth (converted) | 0.8573 |
| ctw1500 | ctw1500 (converted) | 0.6304 |
| totaltext | totaltext (converted) | 0.6596 |
All the models are trained on 8x A100 GPUs with AMP enabled (`--amp`). The overall batch size is 64 (8 GPUs × a per-GPU batch size of 8).
| Name | Pretrained | Generic | Weak | Strong | Download |
| ---------- | --------------------------------------------------------------------------------------- | ------- | ----- | ------ | ------------------------------------------------------------------------------------- |
| ICDAR 2013 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/spts_resnet50_150e_pretrain-spts-c9fe4c78.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/20230223_194550.log) | 87.10 | 91.46 | 93.41 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2013/spts_resnet50_200e_icdar2013-64cb4d31.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2013/20230303_140316.log) |
| ICDAR 2015 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/spts_resnet50_150e_pretrain-spts-c9fe4c78.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/20230223_194550.log) | 69.09 | 73.45 | 79.19 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2015/spts_resnet50_200e_icdar2015-d6e8621c.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2015/20230302_230026.log) |
| Name | Pretrained | None-Hmean | Full-Hmean | Download |
| :-------: | -------------------------------------------------------------------------------------- | :--------: | :--------: | ------------------------------------------------------------------------------------- |
| Totaltext | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/spts_resnet50_150e_pretrain-spts-c9fe4c78.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/20230223_194550.log) | 73.99 | 82.34 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_totaltext/spts_resnet50_200e_totaltext-e3521af6.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_totaltext/20230303_103040.log) |
## Citation
@ -136,15 +155,15 @@ A project does not necessarily have to be finished in a single PR, but it's esse
<!-- As this template does. -->
- [ ] Milestone 2: Indicates a successful model implementation.
- [x] Milestone 2: Indicates a successful model implementation.
- [ ] Training-time correctness
- [x] Training-time correctness
<!-- If you are reproducing the result from a paper, checking this item means that you should have trained your model from scratch based on the original paper's specification and verified that the final result matches the report within a minor error range. -->
- [ ] Milestone 3: Good to be a part of our core package!
- [x] Milestone 3: Good to be a part of our core package!
- [ ] Type hints and docstrings
- [x] Type hints and docstrings
<!-- Ideally *all* the methods should have [type hints](https://www.pythontutorial.net/python-basics/python-type-hints/) and [docstrings](https://google.github.io/styleguide/pyguide.html#381-docstrings). [Example](https://github.com/open-mmlab/mmocr/blob/76637a290507f151215d299707c57cea5120976e/mmocr/utils/polygon_utils.py#L80-L96) -->

View File

@ -1,4 +1,4 @@
icdar2013_textspotting_data_root = 'data/icdar2013'
icdar2013_textspotting_data_root = 'spts-data/icdar2013'
icdar2013_textspotting_train = dict(
type='AdelDataset',

View File

@ -1,4 +1,4 @@
icdar2015_textspotting_data_root = 'data/icdar2015'
icdar2015_textspotting_data_root = 'spts-data/icdar2015'
icdar2015_textspotting_train = dict(
type='AdelDataset',

View File

@ -11,5 +11,4 @@ icdar2015_textspotting_test = dict(
data_root=icdar2015_textspotting_data_root,
ann_file='textspotting_test.json',
test_mode=True,
# indices=50,
pipeline=None)

View File

@ -1,4 +1,4 @@
mlt_textspotting_data_root = 'data/mlt2017'
mlt_textspotting_data_root = 'spts-data/mlt2017'
mlt_textspotting_train = dict(
type='AdelDataset',

View File

@ -1,4 +1,4 @@
syntext1_textspotting_data_root = 'data/syntext1'
syntext1_textspotting_data_root = 'spts-data/syntext1'
syntext1_textspotting_train = dict(
type='AdelDataset',

View File

@ -1,4 +1,4 @@
syntext2_textspotting_data_root = 'data/syntext2'
syntext2_textspotting_data_root = 'spts-data/syntext2'
syntext2_textspotting_train = dict(
type='AdelDataset',

View File

@ -1,4 +1,4 @@
totaltext_textspotting_data_root = 'data/totaltext'
totaltext_textspotting_data_root = 'spts-data/totaltext'
totaltext_textspotting_train = dict(
type='AdelDataset',

View File

@ -0,0 +1,15 @@
totaltext_textspotting_data_root = 'data/totaltext'
totaltext_textspotting_train = dict(
type='OCRDataset',
data_root=totaltext_textspotting_data_root,
ann_file='textspotting_train.json',
filter_cfg=dict(filter_empty_gt=True, min_size=32),
pipeline=None)
totaltext_textspotting_test = dict(
type='OCRDataset',
data_root=totaltext_textspotting_data_root,
ann_file='textspotting_test.json',
test_mode=True,
pipeline=None)

View File

@ -4,7 +4,8 @@ env_cfg = dict(
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
randomness = dict(seed=None)
randomness = dict(seed=42)
default_hooks = dict(
timer=dict(type='IterTimerHook'),

View File

@ -1,4 +1,5 @@
custom_imports = dict(imports=['spts'], allow_failed_imports=False)
custom_imports = dict(
imports=['projects.SPTS.spts'], allow_failed_imports=False)
file_client_args = dict(backend='disk')
@ -65,10 +66,7 @@ test_pipeline = [
type='LoadOCRAnnotationsWithBezier',
with_bbox=True,
with_label=True,
with_bezier=True,
with_text=True),
dict(type='Bezier2Polygon'),
dict(type='ConvertText', dictionary=dictionary),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
@ -87,7 +85,7 @@ train_pipeline = [
with_text=True),
dict(type='Bezier2Polygon'),
dict(type='FixInvalidPolygon'),
dict(type='ConvertText', dictionary=dictionary),
dict(type='ConvertText', dictionary=dict(**dictionary, num_bins=0)),
dict(type='RemoveIgnored'),
dict(type='RandomCrop', min_side_ratio=0.5),
dict(
@ -119,7 +117,6 @@ train_pipeline = [
hue=0.5)
],
prob=0.5),
# dict(type='Polygon2Bezier'),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))

View File

@ -0,0 +1,63 @@
_base_ = '_base_spts_resnet50.py'
test_pipeline = [
dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
dict(
type='RescaleToShortSide',
short_side_lens=[1000],
long_side_bound=1824),
dict(
type='LoadOCRAnnotations',
with_bbox=True,
with_label=True,
with_polygon=True,
with_text=True),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
]
train_pipeline = [
dict(type='LoadImageFromFile', color_type='color_ignore_orientation'),
dict(
type='LoadOCRAnnotations',
with_bbox=True,
with_label=True,
with_polygon=True,
with_text=True),
dict(type='FixInvalidPolygon'),
dict(type='RemoveIgnored'),
dict(type='RandomCrop', min_side_ratio=0.5),
dict(
type='RandomApply',
transforms=[
dict(
type='RandomRotate',
max_angle=30,
pad_with_fixed_color=True,
use_canvas=True)
],
prob=0.3),
dict(type='FixInvalidPolygon'),
dict(
type='RandomChoiceResize',
scales=[(640, 1600), (672, 1600), (704, 1600), (736, 1600),
(768, 1600), (800, 1600), (832, 1600), (864, 1600),
(896, 1600)],
keep_ratio=True),
dict(
type='RandomApply',
transforms=[
dict(
type='TorchVisionWrapper',
op='ColorJitter',
brightness=0.5,
contrast=0.5,
saturation=0.5,
hue=0.5)
],
prob=0.5),
dict(
type='PackTextDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor'))
]

View File

@ -1,59 +0,0 @@
_base_ = [
'_base_spts_resnet50.py',
'../_base_/datasets/ctw1500-spts.py',
'../_base_/default_runtime.py',
]
num_epochs = 350
lr = 0.00001
min_lr = 0.00001
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001),
paramwise_cfg=dict(custom_keys={
'backbone': dict(lr_mult=0.1),
}))
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# learning policy
param_scheduler = [
dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True),
dict(
type='LinearLR',
begin=5,
end=min(num_epochs,
int((lr - min_lr) / (lr / num_epochs)) + 5),
end_factor=min_lr / lr,
by_epoch=True),
]
# dataset settings
ctw1500_textspotting_train = _base_.ctw1500_textspotting_train
ctw1500_textspotting_train.pipeline = _base_.train_pipeline
ctw1500_textspotting_test = _base_.ctw1500_textspotting_test
ctw1500_textspotting_test.pipeline = _base_.test_pipeline
train_dataloader = dict(
batch_size=4,
num_workers=8,
persistent_workers=True,
sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2),
dataset=ctw1500_textspotting_train)
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=ctw1500_textspotting_test)
test_dataloader = val_dataloader
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
custom_imports = dict(imports='spts')

View File

@ -1,57 +0,0 @@
_base_ = [
'_base_spts_resnet50.py',
'../_base_/datasets/icdar2013-spts.py',
'../_base_/default_runtime.py',
]
num_epochs = 350
lr = 0.00001
min_lr = 0.00001
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001),
paramwise_cfg=dict(custom_keys={
'backbone': dict(lr_mult=0.1),
}))
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# learning policy
param_scheduler = [
dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True),
dict(
type='LinearLR',
begin=5,
end=min(num_epochs,
int((lr - min_lr) / (lr / num_epochs)) + 5),
end_factor=min_lr / lr,
by_epoch=True),
]
# dataset settings
icdar2013_textspotting_train = _base_.icdar2013_textspotting_train
icdar2013_textspotting_train.pipeline = _base_.train_pipeline
icdar2013_textspotting_test = _base_.icdar2013_textspotting_test
icdar2013_textspotting_test.pipeline = _base_.test_pipeline
train_dataloader = dict(
batch_size=4,
num_workers=8,
persistent_workers=True,
sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2),
dataset=icdar2013_textspotting_train)
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=icdar2013_textspotting_test)
test_dataloader = val_dataloader
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

View File

@ -1,57 +0,0 @@
_base_ = [
'_base_spts_resnet50.py',
'../_base_/datasets/icdar2015-spts.py',
'../_base_/default_runtime.py',
]
num_epochs = 350
lr = 0.00001
min_lr = 0.00001
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001),
paramwise_cfg=dict(custom_keys={
'backbone': dict(lr_mult=0.1),
}))
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# learning policy
param_scheduler = [
dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True),
dict(
type='LinearLR',
begin=5,
end=min(num_epochs,
int((lr - min_lr) / (lr / num_epochs)) + 5),
end_factor=min_lr / lr,
by_epoch=True),
]
# dataset settings
icdar2015_textspotting_train = _base_.icdar2015_textspotting_train
icdar2015_textspotting_train.pipeline = _base_.train_pipeline
icdar2015_textspotting_test = _base_.icdar2015_textspotting_test
icdar2015_textspotting_test.pipeline = _base_.test_pipeline
train_dataloader = dict(
batch_size=4,
num_workers=8,
persistent_workers=True,
sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2),
dataset=icdar2015_textspotting_train)
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=icdar2015_textspotting_test)
test_dataloader = val_dataloader
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

View File

@ -2,7 +2,6 @@ _base_ = [
'_base_spts_resnet50.py',
'../_base_/datasets/icdar2013-spts.py',
'../_base_/datasets/icdar2015-spts.py',
'../_base_/datasets/ctw1500-spts.py',
'../_base_/datasets/totaltext-spts.py',
'../_base_/datasets/syntext1-spts.py',
'../_base_/datasets/syntext2-spts.py',
@ -16,15 +15,12 @@ min_lr = 0.00001
optim_wrapper = dict(
type='OptimWrapper',
accumulative_counts=2,
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001),
paramwise_cfg=dict(custom_keys={
'backbone': dict(lr_mult=0.1),
}))
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=20)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=5)
# learning policy
param_scheduler = [
dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True),
@ -39,26 +35,24 @@ param_scheduler = [
# dataset settings
train_list = [
_base_.icdar2013_textspotting_train, _base_.icdar2015_textspotting_train,
_base_.mlt_textspotting_train, _base_.totaltext_textspotting_train,
_base_.syntext1_textspotting_train, _base_.syntext2_textspotting_train,
_base_.ctw1500_textspotting_train
_base_.icdar2013_textspotting_train,
_base_.icdar2015_textspotting_train,
_base_.mlt_textspotting_train,
_base_.totaltext_textspotting_train,
_base_.syntext1_textspotting_train,
_base_.syntext2_textspotting_train,
]
train_dataset = dict(
type='ConcatDataset', datasets=train_list, pipeline=_base_.train_pipeline)
train_dataloader = dict(
batch_size=4,
num_workers=8,
persistent_workers=True,
sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2),
dataset=train_dataset)
val_dataloader = None
test_dataloader = None
val_evaluator = None
test_evaluator = None
val_cfg = None
test_cfg = None
train_dataloader = dict(
batch_size=8,
num_workers=8,
pin_memory=True,
persistent_workers=True,
sampler=dict(type='BatchAugSampler', shuffle=True, num_repeats=2),
dataset=train_dataset)
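
For context on the sampler change above: `BatchAugSampler` implements repeated augmentation, emitting each sampled index `num_repeats` times in a row so the copies land in the same batch and receive independent augmentations from the training pipeline. A toy illustration of the resulting index stream (not the sampler's actual code):

```python
# Repeated augmentation with num_repeats=2: every index is emitted twice.
indices = [3, 0, 2]  # hypothetical shuffled dataset indices
repeated = [i for i in indices for _ in range(2)]
print(repeated)  # [3, 3, 0, 0, 2, 2]
```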

View File

@ -0,0 +1,87 @@
_base_ = [
'_base_spts_resnet50_mmocr.py',
'../_base_/datasets/icdar2013.py',
'../_base_/default_runtime.py',
]
load_from = 'work_dirs/spts_resnet50_150e_pretrain-spts/epoch_150.pth'
num_epochs = 200
lr = 0.00001
default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
save_best='generic/hmean',
rule='greater',
_delete_=True),
logger=dict(type='LoggerHook', interval=1))
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001),
paramwise_cfg=dict(custom_keys={
'backbone': dict(lr_mult=0.1),
}))
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# dataset settings
icdar2013_textspotting_train = _base_.icdar2013_textspotting_train
icdar2013_textspotting_train.pipeline = _base_.train_pipeline
icdar2013_textspotting_test = _base_.icdar2013_textspotting_test
icdar2013_textspotting_test.pipeline = _base_.test_pipeline
train_dataloader = dict(
batch_size=8,
num_workers=8,
pin_memory=True,
persistent_workers=True,
sampler=dict(type='BatchAugSampler', shuffle=True, num_repeats=2),
dataset=icdar2013_textspotting_train)
val_dataloader = dict(
batch_size=1,
num_workers=4,
pin_memory=True,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=icdar2013_textspotting_test)
test_dataloader = val_dataloader
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
val_evaluator = [
dict(
type='E2EPointMetric',
prefix='generic',
lexicon_path='data/icdar2013/lexicons/GenericVocabulary_new.txt',
pair_path='data/icdar2013/lexicons/'
'GenericVocabulary_pair_list.txt',
match_dist_thr=None),
dict(
type='E2EPointMetric',
prefix='weak',
lexicon_path='data/icdar2013/lexicons/'
'ch2_test_vocabulary_new.txt',
pair_path='data/icdar2013/lexicons/'
'ch2_test_vocabulary_pair_list.txt',
match_dist_thr=0.4),
dict(
type='E2EPointMetric',
prefix='strong',
lexicon_path='data/icdar2013/lexicons/'
'new_strong_lexicon/lexicons/',
lexicon_mapping=('(.*).jpg', r'new_voc_\1.txt'),
pair_path='data/icdar2013/lexicons/'
'new_strong_lexicon/pairs/',
pair_mapping=('(.*).jpg', r'pair_voc_\1.txt'),
match_dist_thr=0.4),
]
test_evaluator = val_evaluator

View File

@ -0,0 +1,87 @@
_base_ = [
'_base_spts_resnet50_mmocr.py',
'../_base_/datasets/icdar2015.py',
'../_base_/default_runtime.py',
]
load_from = 'work_dirs/spts_resnet50_150e_pretrain-spts/epoch_150.pth'
num_epochs = 200
lr = 0.00001
default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
save_best='generic/hmean',
rule='greater',
_delete_=True),
logger=dict(type='LoggerHook', interval=10))
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001),
paramwise_cfg=dict(custom_keys={
'backbone': dict(lr_mult=0.1),
}))
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# dataset settings
icdar2015_textspotting_train = _base_.icdar2015_textspotting_train
icdar2015_textspotting_train.pipeline = _base_.train_pipeline
icdar2015_textspotting_test = _base_.icdar2015_textspotting_test
icdar2015_textspotting_test.pipeline = _base_.test_pipeline
train_dataloader = dict(
batch_size=8,
num_workers=8,
pin_memory=True,
persistent_workers=True,
sampler=dict(type='BatchAugSampler', shuffle=True, num_repeats=2),
dataset=icdar2015_textspotting_train)
val_dataloader = dict(
batch_size=1,
num_workers=4,
pin_memory=True,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=icdar2015_textspotting_test)
test_dataloader = val_dataloader
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
val_evaluator = [
dict(
type='E2EPointMetric',
prefix='generic',
lexicon_path='data/icdar2015/lexicons/GenericVocabulary_new.txt',
pair_path='data/icdar2015/lexicons/'
'GenericVocabulary_pair_list.txt',
match_dist_thr=None),
dict(
type='E2EPointMetric',
prefix='weak',
lexicon_path='data/icdar2015/lexicons/'
'ch4_test_vocabulary_new.txt',
pair_path='data/icdar2015/lexicons/'
'ch4_test_vocabulary_pair_list.txt',
match_dist_thr=0.4),
dict(
type='E2EPointMetric',
prefix='strong',
lexicon_path='data/icdar2015/lexicons/'
'new_strong_lexicon/lexicons/',
lexicon_mapping=('(.*).jpg', r'new_voc_\1.txt'),
pair_path='data/icdar2015/lexicons/'
'new_strong_lexicon/pairs/',
pair_mapping=('(.*).jpg', r'pair_voc_\1.txt'),
match_dist_thr=0.4),
]
test_evaluator = val_evaluator

View File

@ -1,12 +1,21 @@
_base_ = [
'_base_spts_resnet50.py',
'../_base_/datasets/totaltext-spts.py',
'_base_spts_resnet50_mmocr.py',
'../_base_/datasets/totaltext.py',
'../_base_/default_runtime.py',
]
num_epochs = 350
load_from = 'work_dirs/spts_resnet50_150e_pretrain-spts/epoch_150.pth'
num_epochs = 200
lr = 0.00001
min_lr = 0.00001
default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
save_best='none/hmean',
rule='greater',
_delete_=True),
logger=dict(type='LoggerHook', interval=10))
optim_wrapper = dict(
type='OptimWrapper',
@ -19,17 +28,6 @@ train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
# learning policy
param_scheduler = [
dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True),
dict(
type='LinearLR',
begin=5,
end=min(num_epochs,
int((lr - min_lr) / (lr / num_epochs)) + 5),
end_factor=min_lr / lr,
by_epoch=True),
]
# dataset settings
totaltext_textspotting_train = _base_.totaltext_textspotting_train
@ -38,16 +36,18 @@ totaltext_textspotting_test = _base_.totaltext_textspotting_test
totaltext_textspotting_test.pipeline = _base_.test_pipeline
train_dataloader = dict(
batch_size=4,
batch_size=8,
num_workers=8,
persistent_workers=True,
sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2),
pin_memory=True,
sampler=dict(type='BatchAugSampler', shuffle=True, num_repeats=2),
dataset=totaltext_textspotting_train)
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
pin_memory=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=totaltext_textspotting_test)
@ -55,3 +55,21 @@ test_dataloader = val_dataloader
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
val_evaluator = [
dict(
type='E2EPointMetric',
prefix='none',
word_spotting=True,
match_dist_thr=0.4),
dict(
type='E2EPointMetric',
prefix='full',
lexicon_path='data/totaltext/lexicons/weak_voc_new.txt',
pair_path='data/totaltext/lexicons/'
'weak_voc_pair_list.txt',
word_spotting=True,
match_dist_thr=0.4),
]
test_evaluator = val_evaluator

View File

@ -40,6 +40,7 @@ class AdelDataset(CocoDataset):
None img. The maximum extra number of cycles to get a valid
image. Defaults to 1000.
"""
METAINFO = {'classes': ('text', )}
def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
"""Parse raw annotation to target format.

View File

@ -1,14 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence
import glob
import os.path as osp
import re
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger
from shapely.geometry import LineString, Point
from rapidfuzz.distance import Levenshtein
from shapely.geometry import Point
from mmocr.registry import METRICS
# TODO: CTW1500 read pair
@METRICS.register_module()
class E2EPointMetric(BaseMetric):
@ -17,7 +23,20 @@ class E2EPointMetric(BaseMetric):
Args:
text_score_thrs (dict): Best text score threshold searching
space. Defaults to dict(start=0.8, stop=1, step=0.01).
TODO: docstr
word_spotting (bool): Whether to work in word spotting mode. Defaults
to False.
lexicon_path (str, optional): Lexicon path for word spotting, which
points to a lexicon file or a directory. Defaults to None.
lexicon_mapping (tuple, optional): The rule to map test image name to
its corresponding lexicon file. Only effective when lexicon path
is a directory. Defaults to ('(.*).jpg', r'\1.txt').
pair_path (str, optional): Pair path for word spotting, which points
to a pair file or a directory. Defaults to None.
pair_mapping (tuple, optional): The rule to map test image name to
its corresponding pair file. Only effective when pair path is a
directory. Defaults to ('(.*).jpg', r'\1.txt').
match_dist_thr (float, optional): Matching distance threshold for
word spotting. Defaults to None.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
@ -31,20 +50,52 @@ class E2EPointMetric(BaseMetric):
def __init__(self,
text_score_thrs: Dict = dict(start=0.8, stop=1, step=0.01),
word_spotting: bool = False,
lexicon_path: Optional[str] = None,
lexicon_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'),
pair_path: Optional[str] = None,
pair_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'),
match_dist_thr: Optional[float] = None,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
self.text_score_thrs = np.arange(**text_score_thrs)
self.word_spotting = word_spotting
self.match_dist_thr = match_dist_thr
if lexicon_path:
self.lexicon_mapping = lexicon_mapping
self.pair_mapping = pair_mapping
self.lexicons = self._read_lexicon(lexicon_path)
self.pairs = self._read_pair(pair_path)
def _read_lexicon(self, lexicon_path: str) -> List[str]:
if lexicon_path.endswith('.txt'):
lexicon = open(lexicon_path, 'r').read().splitlines()
lexicon = [ele.strip() for ele in lexicon]
else:
lexicon = {}
for file in glob.glob(osp.join(lexicon_path, '*.txt')):
basename = osp.basename(file)
lexicon[basename] = self._read_lexicon(file)
return lexicon
def _read_pair(self, pair_path: str) -> Dict[str, str]:
pairs = {}
if pair_path.endswith('.txt'):
pair_lines = open(pair_path, 'r').read().splitlines()
for line in pair_lines:
line = line.strip()
word = line.split(' ')[0].upper()
word_gt = line[len(word) + 1:]
pairs[word] = word_gt
else:
for file in glob.glob(osp.join(pair_path, '*.txt')):
basename = osp.basename(file)
pairs[basename] = self._read_pair(file)
return pairs
def poly_center(self, poly_pts):
poly_pts = np.array(poly_pts).reshape(-1, 2)
num_points = poly_pts.shape[0]
line1 = LineString(poly_pts[int(num_points / 2):])
line2 = LineString(poly_pts[:int(num_points / 2)])
mid_pt1 = np.array(line1.interpolate(0.5, normalized=True).coords[0])
mid_pt2 = np.array(line2.interpolate(0.5, normalized=True).coords[0])
return (mid_pt1 + mid_pt2) / 2
return poly_pts.mean(0)
def process(self, data_batch: Sequence[Dict],
data_samples: Sequence[Dict]) -> None:
@ -87,6 +138,8 @@ class E2EPointMetric(BaseMetric):
~pred_ignore_flags)
result = dict(
# reserved for image-level lexicons
gt_img_name=osp.basename(data_sample.get('img_path', '')),
text_scores=text_scores,
pred_points=pred_points,
gt_points=gt_points,
@ -125,10 +178,33 @@ class E2EPointMetric(BaseMetric):
gt_texts = result['gt_texts']
pred_texts = result['pred_texts']
gt_ignore_flags = result['gt_ignore_flags']
gt_img_name = result['gt_img_name']
# Correct the words with lexicon
pred_dist_flags = np.zeros(len(pred_texts), dtype=bool)
if hasattr(self, 'lexicons'):
for i, pred_text in enumerate(pred_texts):
# If it's an image-level lexicon
if isinstance(self.lexicons, dict):
lexicon_name = self._map_img_name(
gt_img_name, self.lexicon_mapping)
pair_name = self._map_img_name(gt_img_name,
self.pair_mapping)
pred_texts[i], match_dist = self._match_word(
pred_text, self.lexicons[lexicon_name],
self.pairs[pair_name])
else:
pred_texts[i], match_dist = self._match_word(
pred_text, self.lexicons, self.pairs)
if (self.match_dist_thr
and match_dist >= self.match_dist_thr):
# won't even count this as a prediction
pred_dist_flags[i] = True
# Filter out predictions by IoU threshold
for i, text_score_thr in enumerate(self.text_score_thrs):
pred_ignore_flags = text_scores < text_score_thr
pred_ignore_flags = pred_dist_flags | (
text_scores < text_score_thr)
filtered_pred_texts = self._get_true_elements(
pred_texts, ~pred_ignore_flags)
filtered_pred_points = self._get_true_elements(
@ -148,10 +224,8 @@ class E2EPointMetric(BaseMetric):
min_idx = np.argmin(dists)
if gt_texts[min_idx] == '###' or gt_ignore_flags[min_idx]:
continue
# if not gt_matched[min_idx] and self.text_match(
# gt_texts[min_idx].upper(), pred_text.upper()):
if (not gt_matched[min_idx] and gt_texts[min_idx].upper()
== pred_text.upper()):
if not gt_matched[min_idx] and (
pred_text.upper() == gt_texts[min_idx].upper()):
gt_matched[min_idx] = True
num_tp[i] += 1
num_preds[i] += 1
@ -173,6 +247,69 @@ class E2EPointMetric(BaseMetric):
best_eval_results = eval_results
return best_eval_results
def _map_img_name(self, img_name: str, mapping: Tuple[str, str]) -> str:
"""Map the image name to the another one based on mapping."""
return re.sub(mapping[0], mapping[1], img_name)
def _true_indexes(self, array: np.ndarray) -> np.ndarray:
"""Get indexes of True elements from a 1D boolean array."""
return np.where(array)[0]
def _word_spotting_filter(self, gt_ignore_flags: np.ndarray,
gt_texts: List[str]
) -> Tuple[np.ndarray, List[str]]:
"""Filter out gt instances that cannot be in a valid dictionary, and do
some simple preprocessing to texts."""
for i in range(len(gt_texts)):
if gt_ignore_flags[i]:
continue
text = gt_texts[i]
if text[-2:] in ["'s", "'S"]:
text = text[:-2]
text = text.strip('-')
for char in "'!?.:,*\"()·[]/":
text = text.replace(char, ' ')
text = text.strip()
gt_ignore_flags[i] = not self._include_in_dict(text)
if not gt_ignore_flags[i]:
gt_texts[i] = text
return gt_ignore_flags, gt_texts
def _include_in_dict(self, text: str) -> bool:
"""Check if the text could be in a valid dictionary."""
if len(text) != len(text.replace(' ', '')) or len(text) < 3:
return False
not_allowed = '×÷·'
valid_ranges = [(ord(u'a'), ord(u'z')), (ord(u'A'), ord(u'Z')),
(ord(u'À'), ord(u'ƿ')), (ord(u'DŽ'), ord(u'ɿ')),
(ord(u'Ά'), ord(u'Ͽ')), (ord(u'-'), ord(u'-'))]
for char in text:
code = ord(char)
if (not_allowed.find(char) != -1):
return False
valid = any(code >= r[0] and code <= r[1] for r in valid_ranges)
if not valid:
return False
return True
def _match_word(self,
text: str,
lexicons: List[str],
pairs: Optional[Dict[str, str]] = None) -> Tuple[str, int]:
"""Match the text with the lexicons and pairs."""
text = text.upper()
matched_word = ''
matched_dist = 100
for lexicon in lexicons:
lexicon = lexicon.upper()
norm_dist = Levenshtein.distance(text, lexicon)
norm_dist = Levenshtein.normalized_distance(text, lexicon)
if norm_dist < matched_dist:
matched_dist = norm_dist
if pairs:
matched_word = pairs[lexicon]
else:
matched_word = lexicon
return matched_word, matched_dist
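
To make the lexicon machinery above concrete: an image name is mapped to its per-image lexicon/pair file with `re.sub`, and `_match_word` picks the entry with the smallest normalized Levenshtein distance. A small self-contained sketch (image name and lexicon contents are made up):

```python
import re

from rapidfuzz.distance import Levenshtein

# lexicon_mapping=('(.*).jpg', r'new_voc_\1.txt') maps an image name to
# its per-image lexicon file.
print(re.sub('(.*).jpg', r'new_voc_\1.txt', 'img_1.jpg'))  # new_voc_img_1.txt

# Closest-word search by normalized edit distance, mirroring _match_word().
lexicon = ['HELLO', 'WORLD', 'SPOTTING']
text = 'SPOTTNG'  # hypothetical prediction with a missing letter
best = min(lexicon, key=lambda w: Levenshtein.normalized_distance(text, w))
print(best, Levenshtein.normalized_distance(text, best))  # SPOTTING 0.125
# With match_dist_thr=0.4, any match at distance >= 0.4 is discarded.
```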

View File

@ -58,18 +58,15 @@ class SPTSDecoder(BaseDecoder):
max_seq_len=self.max_seq_len,
init_cfg=init_cfg)
self.num_bins = num_bins
self.shifted_seq_end_idx = self.num_bins + self.dictionary.seq_end_idx
self.shifted_start_idx = self.num_bins + self.dictionary.start_idx
actual_num_classes = self.dictionary.num_classes + num_bins
self.embedding = DecoderEmbeddings(
actual_num_classes, self.dictionary.padding_idx + num_bins,
d_model, self.max_seq_len, dropout)
self.embedding = DecoderEmbeddings(self.dictionary.num_classes,
self.dictionary.padding_idx,
d_model, self.max_seq_len, dropout)
self.pos_embedding = PositionEmbeddingSine(d_model // 2)
self.vocab_embed = self._gen_vocab_embed(d_model, d_model,
actual_num_classes, 3)
self.dictionary.num_classes,
3)
encoder_layer = TransformerEncoderLayer(d_model, n_head, d_feedforward,
dropout, 'relu',
normalize_before)
@ -166,7 +163,7 @@ class SPTSDecoder(BaseDecoder):
max_probs = []
seq = torch.zeros(
batch_size, 1, dtype=torch.long).to(
out_enc.device) + self.shifted_start_idx
out_enc.device) + self.dictionary.start_idx
for i in range(self.max_seq_len):
tgt = self.embedding(seq).permute(1, 0, 2)
hs = self.decoder(
@ -182,13 +179,13 @@ class SPTSDecoder(BaseDecoder):
# bins chars unk eos seq_eos sos padding
if i % 27 == 0: # coordinate or eos
out[:, self.num_bins:self.shifted_seq_end_idx] = 0
out[:, self.shifted_seq_end_idx + 1:] = 0
out[:, self.num_bins:self.dictionary.seq_end_idx] = 0
out[:, self.dictionary.seq_end_idx + 1:] = 0
elif i % 27 == 1: # coordinate
out[:, self.num_bins:] = 0
else: # chars
out[:, :self.num_bins] = 0
out[:, self.shifted_seq_end_idx:] = 0
out[:, self.dictionary.seq_end_idx:] = 0
max_prob, extra_seq = torch.max(out, dim=-1, keepdim=True)
# prob, extra_seq = out.topk(dim=-1, k=1)
@ -196,7 +193,7 @@ class SPTSDecoder(BaseDecoder):
# TODO: optimize for multi-batch
seq = torch.cat([seq, extra_seq], dim=-1)
max_probs.append(max_prob)
if extra_seq[0] == self.shifted_seq_end_idx:
if extra_seq[0] == self.dictionary.seq_end_idx:
break
max_probs = torch.cat(max_probs, dim=-1)
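
The `i % 27` schedule above reflects the token layout of a single text instance: 2 point-coordinate tokens followed by the text tokens, where 27 = 2 + max_text_len appears consistent with `max_num_text = (max_seq_len - 1) // (2 + max_text_len)` in the module loss (i.e. an assumed maximum text length of 25). A tiny sketch of the masking schedule under that assumption:

```python
# Which token types are allowed at decoding step i, assuming each instance
# occupies 2 coordinate tokens + 25 character tokens (2 + 25 = 27).
def allowed_tokens(i: int) -> str:
    if i % 27 == 0:
        return 'coordinate-or-seq-eos'  # new instance, or end of sequence
    if i % 27 == 1:
        return 'coordinate'             # second coordinate of the point
    return 'character'                  # remaining slots hold the text

print([allowed_tokens(i) for i in range(3)])
# ['coordinate-or-seq-eos', 'coordinate', 'character']
```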

View File

@ -15,6 +15,8 @@ class SPTSDictionary(Dictionary):
Args:
dict_file (str): The path of the character dict file, in which each
line contains a single character.
num_bins (int): Number of bins dividing the image, which is used to
shift the character indexes. Defaults to 1000.
with_start (bool): The flag to control whether to include the start
token. Defaults to False.
with_end (bool): The flag to control whether to include the end token.
@ -45,6 +47,7 @@ class SPTSDictionary(Dictionary):
def __init__(
self,
dict_file: str,
num_bins: int = 1000,
with_start: bool = False,
with_end: bool = False,
with_seq_end: bool = False,
@ -74,6 +77,26 @@ class SPTSDictionary(Dictionary):
padding_token=padding_token,
unknown_token=unknown_token)
self.num_bins = num_bins
self._shift_idx()
@property
def num_classes(self) -> int:
"""int: Number of output classes. Special tokens are counted.
"""
return len(self._dict) + self.num_bins
def _shift_idx(self):
idx_terms = [
'start_idx', 'end_idx', 'unknown_idx', 'seq_end_idx', 'padding_idx'
]
for term in idx_terms:
value = getattr(self, term)
if value:
setattr(self, term, value + self.num_bins)
for char in self._dict:
self._char2idx[char] += self.num_bins
def _update_dict(self):
"""Update the dict with tokens according to parameters."""
# BOS/EOS
@ -129,10 +152,11 @@ class SPTSDictionary(Dictionary):
assert isinstance(index, (list, tuple))
string = ''
for i in index:
assert i < len(self._dict), f'Index: {i} out of range! Index ' \
f'must be less than {len(self._dict)}'
assert i < self.num_classes, f'Index: {i} out of range! Index ' \
f'must be less than {self.num_classes}'
# TODO: find its difference from ignore_chars
# in TextRecogPostprocessor
if self._dict[i] is not None:
string += self._dict[i]
shifted_i = i - self.num_bins
if self._dict[shifted_i] is not None:
string += self._dict[shifted_i]
return string
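
A worked example of the index shifting above, assuming `num_bins=1000` and a made-up three-character dictionary: classes 0–999 encode quantized point coordinates, character indexes are shifted up by `num_bins`, and `idx2str` shifts them back before the lookup.

```python
# Toy sketch of SPTSDictionary's num_bins index shift.
num_bins = 1000          # classes 0..999 encode quantized coordinates
chars = ['a', 'b', 'c']  # hypothetical dict file contents

char2idx = {c: i + num_bins for i, c in enumerate(chars)}

def idx2str(indexes):
    # Mirrors the updated idx2str: shift back before looking up the char.
    return ''.join(chars[i - num_bins] for i in indexes)

assert char2idx['b'] == 1001
assert idx2str([1000, 1002]) == 'ac'
```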

View File

@ -81,7 +81,7 @@ class SPTSModuleLoss(CEModuleLoss):
self.max_num_text = (self.max_seq_len - 1) // (2 + max_text_len)
self.num_bins = num_bins
weights = torch.ones(self.dictionary.num_classes + num_bins)
weights = torch.ones(self.dictionary.num_classes, dtype=torch.float32)
weights[self.dictionary.seq_end_idx] = seq_eos_coef
weights.requires_grad_ = False
self.loss_ce = nn.CrossEntropyLoss(
@ -117,18 +117,24 @@ class SPTSModuleLoss(CEModuleLoss):
if data_sample.get('have_target', False):
continue
if len(data_sample.gt_instances.polygons) > self.max_num_text:
if len(data_sample.gt_instances) > self.max_num_text:
keep = random.sample(
range(len(data_sample.gt_instances['polygons'])),
self.max_num_text)
range(len(data_sample.gt_instances)), self.max_num_text)
data_sample.gt_instances = data_sample.gt_instances[keep]
gt_instances = data_sample.gt_instances
if len(gt_instances.polygons) > 0:
if len(gt_instances) > 0:
center_pts = []
# Slightly different from the original implementation
# which gets the center points from bezier curves
# for bezier_pt in gt_instances.beziers:
# bezier_pt = bezier_pt.reshape(8, 2)
# mid_pt1 = sample_bezier_curve(
# bezier_pt[:4], mid_point=True)
# mid_pt2 = sample_bezier_curve(
# bezier_pt[4:], mid_point=True)
# center_pt = (mid_pt1 + mid_pt2) / 2
for polygon in gt_instances.polygons:
center_pt = polygon.reshape(-1, 2).mean(0)
center_pts.append(center_pt)
@ -152,7 +158,7 @@ class SPTSModuleLoss(CEModuleLoss):
dtype=torch.long) + self.dictionary.end_idx
max_len = min(self.max_text_len - 1, len(indexes))
indexes_tensor[:max_len] = torch.LongTensor(indexes)[:max_len]
indexes_tensor = indexes_tensor + self.num_bins
indexes_tensor = indexes_tensor
gt_indexes.append(indexes_tensor)
if len(gt_indexes) == 0:
@ -164,15 +170,12 @@ class SPTSModuleLoss(CEModuleLoss):
if self.dictionary.start_idx is not None:
gt_indexes = torch.cat([
torch.LongTensor(
[self.dictionary.start_idx + self.num_bins]),
gt_indexes
torch.LongTensor([self.dictionary.start_idx]), gt_indexes
])
if self.dictionary.seq_end_idx is not None:
gt_indexes = torch.cat([
gt_indexes,
torch.LongTensor(
[self.dictionary.seq_end_idx + self.num_bins])
torch.LongTensor([self.dictionary.seq_end_idx])
])
batch_max_len = max(batch_max_len, len(gt_indexes))
@ -190,7 +193,7 @@ class SPTSModuleLoss(CEModuleLoss):
padded_indexes = (
torch.zeros(batch_max_len, dtype=torch.long) +
self.dictionary.padding_idx + self.num_bins)
self.dictionary.padding_idx)
padded_indexes[:len(indexes)] = indexes
data_sample.gt_instances.set_metainfo(
dict(padded_indexes=padded_indexes))

View File

@ -102,12 +102,11 @@ class SPTSPostprocessor(BaseTextRecogPostprocessor):
for char_index, char_score in zip(output_index[2:],
output_score[2:]):
# the first num_bins indexes are for points
dict_idx = char_index - self.num_bins
if dict_idx in self.ignore_indexes:
if char_index in self.ignore_indexes:
continue
if dict_idx == self.dictionary.end_idx:
if char_index == self.dictionary.end_idx:
break
indexes[-1].append(dict_idx)
indexes[-1].append(char_index)
char_scores.append(char_score)
text_scores.append(np.mean(char_scores).item())
return indexes, text_scores, points, pt_scores