diff --git a/mmocr/utils/polygon_utils.py b/mmocr/utils/polygon_utils.py index 4a103ca4..7c6b857f 100644 --- a/mmocr/utils/polygon_utils.py +++ b/mmocr/utils/polygon_utils.py @@ -152,8 +152,8 @@ def crop_polygon(polygon: ArrayLike, np.array or None: Cropped polygon. If the polygon is not within the crop box, return None. """ - poly = poly2shapely(polygon) - crop_poly = poly2shapely(bbox2poly(crop_box)) + poly = poly_make_valid(poly2shapely(polygon)) + crop_poly = poly_make_valid(poly2shapely(bbox2poly(crop_box))) poly_cropped = poly.intersection(crop_poly) if poly_cropped.area == 0. or not isinstance( poly_cropped, shapely.geometry.polygon.Polygon): diff --git a/projects/SPTS/README.md b/projects/SPTS/README.md index 19847605..8401f36b 100644 --- a/projects/SPTS/README.md +++ b/projects/SPTS/README.md @@ -36,10 +36,24 @@ $env:PYTHONPATH=Get-Location ### Dataset -**As of now, the implementation uses datasets provided by SPTS, but these datasets -will be available in MMOCR's dataset preparer very soon.** +As of now, the implementation uses datasets provided by SPTS for pre-training, and uses MMOCR's datasets for fine-tuning and testing. This is because the test splits of SPTS's datasets do not contain enough information for end-to-end (e2e) evaluation, and MMOCR's dataset preparer does not yet support all the datasets used in SPTS. *We are working on this issue, and the remaining datasets will be available in MMOCR's dataset preparer very soon.* -Download and extract all the datasets into `data/` following [SPTS official guide](https://github.com/shannanyinxiang/SPTS#dataset). +Please follow these steps to prepare the datasets: + +- Download and extract all the SPTS datasets into `spts-data/` following [SPTS official guide](https://github.com/shannanyinxiang/SPTS#dataset). + +- Use [Dataset Preparer](https://mmocr.readthedocs.io/en/dev-1.x/user_guides/data_prepare/dataset_preparer.html) to prepare `icdar2013`, `icdar2015` and `totaltext` for `textspotting` task. 
+ + ```shell + # Run in MMOCR's root directory + python tools/dataset_converters/prepare_dataset.py icdar2013 icdar2015 totaltext --task textspotting + ``` + + Then create a soft link to `data/` directory in the project root directory: + + ```shell + ln -s ../../data/ . + ``` ### Training commands @@ -48,13 +62,13 @@ In the current directory, run the following command to train the model: #### Pretrain ```bash -mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/ +mim train mmocr config/spts/spts_resnet50_8xb8-150e_pretrain-spts.py --work-dir work_dirs/ --amp ``` To train on multiple GPUs, e.g. 8 GPUs, run the following command: ```bash -mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_dirs/ --launcher pytorch --gpus 8 +mim train mmocr config/spts/spts_resnet50_8xb8-150e_pretrain-spts.py --work-dir work_dirs/ --launcher pytorch --gpus 8 --amp ``` #### Finetune @@ -62,13 +76,13 @@ mim train mmocr config/spts/spts_resnet50_150e_pretrain-spts.py --work-dir work_ Similarly, run the following command to finetune the model on a dataset (e.g. icdar2013): ```bash -mim train mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work_dirs/ --cfg-options "load_from={CHECKPOINT_PATH}" +mim train mmocr config/spts/spts_resnet50_8xb8-200e_icdar2013.py --work-dir work_dirs/ --cfg-options "load_from={CHECKPOINT_PATH}" --amp ``` To finetune on multiple GPUs, e.g. 
8 GPUs, run the following command: ```bash -mim train mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work_dirs/ --launcher pytorch --gpus 8 --cfg-options "load_from={CHECKPOINT_PATH}" +mim train mmocr config/spts/spts_resnet50_8xb8-200e_icdar2013.py --work-dir work_dirs/ --launcher pytorch --gpus 8 --cfg-options "load_from={CHECKPOINT_PATH}" --amp ``` ### Testing commands @@ -76,24 +90,29 @@ mim train mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work In the current directory, run the following command to test the model on a dataset (e.g. icdar2013): ```bash -mim test mmocr config/spts/spts_resnet50_350e_icdar2013-spts.py --work-dir work_dirs/ --checkpoint ${CHECKPOINT_PATH} +mim test mmocr config/spts/spts_resnet50_8xb8-200e_icdar2013.py --work-dir work_dirs/ --checkpoint ${CHECKPOINT_PATH} ``` -## Results +## Convert Weights from Official Repo -The weights from MMOCR are on the way. Users may download the weights from [SPTS](https://github.com/shannanyinxiang/SPTS#inference) and use the conversion script to convert them into MMOCR format. +Users may download the weights from [SPTS](https://github.com/shannanyinxiang/SPTS#inference) and use the conversion script to convert them into MMOCR format. ```bash python tools/ckpt_adapter.py [SPTS_WEIGHTS_PATH] [MMOCR_WEIGHTS_PATH] ``` -Here are the results obtained on the converted weights. The results are lower than the original ones due to the difference in the test split of datasets, which will be addressed in next update. +## Results -| Name | Model | E2E-None-Hmean | -| :--------: | :-------------------: | :------------: | -| ICDAR 2013 | ic13.pth (converted) | 0.8573 | -| ctw1500 | ctw1500 (converted) | 0.6304 | -| totaltext | totaltext (converted) | 0.6596 | +All the models are trained on 8x A100 GPUs with AMP on (`--amp`). The overall batch size is 64. 
+ +| Name | Pretrained | Generic | Weak | Strong | Download | +| ---------- | --------------------------------------------------------------------------------------- | ------- | ----- | ------ | ------------------------------------------------------------------------------------- | +| ICDAR 2013 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/spts_resnet50_150e_pretrain-spts-c9fe4c78.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/20230223_194550.log) | 87.10 | 91.46 | 93.41 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2013/spts_resnet50_200e_icdar2013-64cb4d31.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2013/20230303_140316.log) | +| ICDAR 2015 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/spts_resnet50_150e_pretrain-spts-c9fe4c78.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/20230223_194550.log) | 69.09 | 73.45 | 79.19 | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2015/spts_resnet50_200e_icdar2015-d6e8621c.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_icdar2015/20230302_230026.log) | + +| Name | Pretrained | None-Hmean | Full-Hmean | Download | +| :-------: | -------------------------------------------------------------------------------------- | :--------: | :--------: | ------------------------------------------------------------------------------------- | +| Totaltext | [model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/spts_resnet50_150e_pretrain-spts-c9fe4c78.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_150e_pretrain-spts/20230223_194550.log) | 73.99 | 82.34 | 
[model](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_totaltext/spts_resnet50_200e_totaltext-e3521af6.pth) / [log](https://download.openmmlab.com/mmocr/textspotting/spts/spts_resnet50_200e_totaltext/20230303_103040.log) | ## Citation @@ -136,15 +155,15 @@ A project does not necessarily have to be finished in a single PR, but it's esse -- [ ] Milestone 2: Indicates a successful model implementation. +- [x] Milestone 2: Indicates a successful model implementation. - - [ ] Training-time correctness + - [x] Training-time correctness -- [ ] Milestone 3: Good to be a part of our core package! +- [x] Milestone 3: Good to be a part of our core package! - - [ ] Type hints and docstrings + - [x] Type hints and docstrings diff --git a/projects/SPTS/config/_base_/datasets/icdar2013-spts.py b/projects/SPTS/config/_base_/datasets/icdar2013-spts.py index 87304803..61f57d44 100644 --- a/projects/SPTS/config/_base_/datasets/icdar2013-spts.py +++ b/projects/SPTS/config/_base_/datasets/icdar2013-spts.py @@ -1,4 +1,4 @@ -icdar2013_textspotting_data_root = 'data/icdar2013' +icdar2013_textspotting_data_root = 'spts-data/icdar2013' icdar2013_textspotting_train = dict( type='AdelDataset', diff --git a/projects/SPTS/config/_base_/datasets/icdar2015-spts.py b/projects/SPTS/config/_base_/datasets/icdar2015-spts.py index 6ea93cb5..df0139ad 100644 --- a/projects/SPTS/config/_base_/datasets/icdar2015-spts.py +++ b/projects/SPTS/config/_base_/datasets/icdar2015-spts.py @@ -1,4 +1,4 @@ -icdar2015_textspotting_data_root = 'data/icdar2015' +icdar2015_textspotting_data_root = 'spts-data/icdar2015' icdar2015_textspotting_train = dict( type='AdelDataset', diff --git a/projects/SPTS/config/_base_/datasets/icdar2015.py b/projects/SPTS/config/_base_/datasets/icdar2015.py index 240f1347..f71a7214 100644 --- a/projects/SPTS/config/_base_/datasets/icdar2015.py +++ b/projects/SPTS/config/_base_/datasets/icdar2015.py @@ -11,5 +11,4 @@ icdar2015_textspotting_test = dict( 
data_root=icdar2015_textspotting_data_root, ann_file='textspotting_test.json', test_mode=True, - # indices=50, pipeline=None) diff --git a/projects/SPTS/config/_base_/datasets/mlt-spts.py b/projects/SPTS/config/_base_/datasets/mlt-spts.py index b99fe147..22d45039 100644 --- a/projects/SPTS/config/_base_/datasets/mlt-spts.py +++ b/projects/SPTS/config/_base_/datasets/mlt-spts.py @@ -1,4 +1,4 @@ -mlt_textspotting_data_root = 'data/mlt2017' +mlt_textspotting_data_root = 'spts-data/mlt2017' mlt_textspotting_train = dict( type='AdelDataset', diff --git a/projects/SPTS/config/_base_/datasets/syntext1-spts.py b/projects/SPTS/config/_base_/datasets/syntext1-spts.py index 19b6be64..d30df532 100644 --- a/projects/SPTS/config/_base_/datasets/syntext1-spts.py +++ b/projects/SPTS/config/_base_/datasets/syntext1-spts.py @@ -1,4 +1,4 @@ -syntext1_textspotting_data_root = 'data/syntext1' +syntext1_textspotting_data_root = 'spts-data/syntext1' syntext1_textspotting_train = dict( type='AdelDataset', diff --git a/projects/SPTS/config/_base_/datasets/syntext2-spts.py b/projects/SPTS/config/_base_/datasets/syntext2-spts.py index 53709ead..6fb06e30 100644 --- a/projects/SPTS/config/_base_/datasets/syntext2-spts.py +++ b/projects/SPTS/config/_base_/datasets/syntext2-spts.py @@ -1,4 +1,4 @@ -syntext2_textspotting_data_root = 'data/syntext2' +syntext2_textspotting_data_root = 'spts-data/syntext2' syntext2_textspotting_train = dict( type='AdelDataset', diff --git a/projects/SPTS/config/_base_/datasets/totaltext-spts.py b/projects/SPTS/config/_base_/datasets/totaltext-spts.py index ae6a2316..37bea881 100644 --- a/projects/SPTS/config/_base_/datasets/totaltext-spts.py +++ b/projects/SPTS/config/_base_/datasets/totaltext-spts.py @@ -1,4 +1,4 @@ -totaltext_textspotting_data_root = 'data/totaltext' +totaltext_textspotting_data_root = 'spts-data/totaltext' totaltext_textspotting_train = dict( type='AdelDataset', diff --git a/projects/SPTS/config/_base_/datasets/totaltext.py 
b/projects/SPTS/config/_base_/datasets/totaltext.py new file mode 100644 index 00000000..ddc8f32f --- /dev/null +++ b/projects/SPTS/config/_base_/datasets/totaltext.py @@ -0,0 +1,15 @@ +totaltext_textspotting_data_root = 'data/totaltext' + +totaltext_textspotting_train = dict( + type='OCRDataset', + data_root=totaltext_textspotting_data_root, + ann_file='textspotting_train.json', + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=None) + +totaltext_textspotting_test = dict( + type='OCRDataset', + data_root=totaltext_textspotting_data_root, + ann_file='textspotting_test.json', + test_mode=True, + pipeline=None) diff --git a/projects/SPTS/config/_base_/default_runtime.py b/projects/SPTS/config/_base_/default_runtime.py index edd2219a..22657075 100644 --- a/projects/SPTS/config/_base_/default_runtime.py +++ b/projects/SPTS/config/_base_/default_runtime.py @@ -4,7 +4,8 @@ env_cfg = dict( mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), dist_cfg=dict(backend='nccl'), ) -randomness = dict(seed=None) + +randomness = dict(seed=42) default_hooks = dict( timer=dict(type='IterTimerHook'), diff --git a/projects/SPTS/config/spts/_base_spts_resnet50.py b/projects/SPTS/config/spts/_base_spts_resnet50.py index 02e95300..895df0b3 100644 --- a/projects/SPTS/config/spts/_base_spts_resnet50.py +++ b/projects/SPTS/config/spts/_base_spts_resnet50.py @@ -1,4 +1,5 @@ -custom_imports = dict(imports=['spts'], allow_failed_imports=False) +custom_imports = dict( + imports=['projects.SPTS.spts'], allow_failed_imports=False) file_client_args = dict(backend='disk') @@ -65,10 +66,7 @@ test_pipeline = [ type='LoadOCRAnnotationsWithBezier', with_bbox=True, with_label=True, - with_bezier=True, with_text=True), - dict(type='Bezier2Polygon'), - dict(type='ConvertText', dictionary=dictionary), dict( type='PackTextDetInputs', meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) @@ -87,7 +85,7 @@ train_pipeline = [ with_text=True), dict(type='Bezier2Polygon'), 
dict(type='FixInvalidPolygon'), - dict(type='ConvertText', dictionary=dictionary), + dict(type='ConvertText', dictionary=dict(**dictionary, num_bins=0)), dict(type='RemoveIgnored'), dict(type='RandomCrop', min_side_ratio=0.5), dict( @@ -119,7 +117,6 @@ train_pipeline = [ hue=0.5) ], prob=0.5), - # dict(type='Polygon2Bezier'), dict( type='PackTextDetInputs', meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) diff --git a/projects/SPTS/config/spts/_base_spts_resnet50_mmocr.py b/projects/SPTS/config/spts/_base_spts_resnet50_mmocr.py new file mode 100644 index 00000000..f4242e36 --- /dev/null +++ b/projects/SPTS/config/spts/_base_spts_resnet50_mmocr.py @@ -0,0 +1,63 @@ +_base_ = '_base_spts_resnet50.py' + +test_pipeline = [ + dict(type='LoadImageFromFile', color_type='color_ignore_orientation'), + dict( + type='RescaleToShortSide', + short_side_lens=[1000], + long_side_bound=1824), + dict( + type='LoadOCRAnnotations', + with_bbox=True, + with_label=True, + with_polygon=True, + with_text=True), + dict( + type='PackTextDetInputs', + meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) +] + +train_pipeline = [ + dict(type='LoadImageFromFile', color_type='color_ignore_orientation'), + dict( + type='LoadOCRAnnotations', + with_bbox=True, + with_label=True, + with_polygon=True, + with_text=True), + dict(type='FixInvalidPolygon'), + dict(type='RemoveIgnored'), + dict(type='RandomCrop', min_side_ratio=0.5), + dict( + type='RandomApply', + transforms=[ + dict( + type='RandomRotate', + max_angle=30, + pad_with_fixed_color=True, + use_canvas=True) + ], + prob=0.3), + dict(type='FixInvalidPolygon'), + dict( + type='RandomChoiceResize', + scales=[(640, 1600), (672, 1600), (704, 1600), (736, 1600), + (768, 1600), (800, 1600), (832, 1600), (864, 1600), + (896, 1600)], + keep_ratio=True), + dict( + type='RandomApply', + transforms=[ + dict( + type='TorchVisionWrapper', + op='ColorJitter', + brightness=0.5, + contrast=0.5, + saturation=0.5, + hue=0.5) + ], + 
prob=0.5), + dict( + type='PackTextDetInputs', + meta_keys=('img_path', 'ori_shape', 'img_shape', 'scale_factor')) +] diff --git a/projects/SPTS/config/spts/spts_resnet50_350e_ctw1500-spts.py b/projects/SPTS/config/spts/spts_resnet50_350e_ctw1500-spts.py deleted file mode 100644 index 72a91a5c..00000000 --- a/projects/SPTS/config/spts/spts_resnet50_350e_ctw1500-spts.py +++ /dev/null @@ -1,59 +0,0 @@ -_base_ = [ - '_base_spts_resnet50.py', - '../_base_/datasets/ctw1500-spts.py', - '../_base_/default_runtime.py', -] - -num_epochs = 350 -lr = 0.00001 -min_lr = 0.00001 - -optim_wrapper = dict( - type='OptimWrapper', - optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001), - paramwise_cfg=dict(custom_keys={ - 'backbone': dict(lr_mult=0.1), - })) - -train_cfg = dict( - type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') -# learning policy -param_scheduler = [ - dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True), - dict( - type='LinearLR', - begin=5, - end=min(num_epochs, - int((lr - min_lr) / (lr / num_epochs)) + 5), - end_factor=min_lr / lr, - by_epoch=True), -] - -# dataset settings -ctw1500_textspotting_train = _base_.ctw1500_textspotting_train -ctw1500_textspotting_train.pipeline = _base_.train_pipeline -ctw1500_textspotting_test = _base_.ctw1500_textspotting_test -ctw1500_textspotting_test.pipeline = _base_.test_pipeline - -train_dataloader = dict( - batch_size=4, - num_workers=8, - persistent_workers=True, - sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2), - dataset=ctw1500_textspotting_train) - -val_dataloader = dict( - batch_size=1, - num_workers=4, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=ctw1500_textspotting_test) - -test_dataloader = val_dataloader - -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') - -custom_imports = dict(imports='spts') diff --git 
a/projects/SPTS/config/spts/spts_resnet50_350e_icdar2013-spts.py b/projects/SPTS/config/spts/spts_resnet50_350e_icdar2013-spts.py deleted file mode 100644 index 48f83cc5..00000000 --- a/projects/SPTS/config/spts/spts_resnet50_350e_icdar2013-spts.py +++ /dev/null @@ -1,57 +0,0 @@ -_base_ = [ - '_base_spts_resnet50.py', - '../_base_/datasets/icdar2013-spts.py', - '../_base_/default_runtime.py', -] - -num_epochs = 350 -lr = 0.00001 -min_lr = 0.00001 - -optim_wrapper = dict( - type='OptimWrapper', - optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001), - paramwise_cfg=dict(custom_keys={ - 'backbone': dict(lr_mult=0.1), - })) - -train_cfg = dict( - type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') -# learning policy -param_scheduler = [ - dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True), - dict( - type='LinearLR', - begin=5, - end=min(num_epochs, - int((lr - min_lr) / (lr / num_epochs)) + 5), - end_factor=min_lr / lr, - by_epoch=True), -] - -# dataset settings -icdar2013_textspotting_train = _base_.icdar2013_textspotting_train -icdar2013_textspotting_train.pipeline = _base_.train_pipeline -icdar2013_textspotting_test = _base_.icdar2013_textspotting_test -icdar2013_textspotting_test.pipeline = _base_.test_pipeline - -train_dataloader = dict( - batch_size=4, - num_workers=8, - persistent_workers=True, - sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2), - dataset=icdar2013_textspotting_train) - -val_dataloader = dict( - batch_size=1, - num_workers=4, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=icdar2013_textspotting_test) - -test_dataloader = val_dataloader - -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') diff --git a/projects/SPTS/config/spts/spts_resnet50_350e_icdar2015-spts.py b/projects/SPTS/config/spts/spts_resnet50_350e_icdar2015-spts.py deleted file mode 100644 index 
5ce3f3c0..00000000 --- a/projects/SPTS/config/spts/spts_resnet50_350e_icdar2015-spts.py +++ /dev/null @@ -1,57 +0,0 @@ -_base_ = [ - '_base_spts_resnet50.py', - '../_base_/datasets/icdar2015-spts.py', - '../_base_/default_runtime.py', -] - -num_epochs = 350 -lr = 0.00001 -min_lr = 0.00001 - -optim_wrapper = dict( - type='OptimWrapper', - optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001), - paramwise_cfg=dict(custom_keys={ - 'backbone': dict(lr_mult=0.1), - })) - -train_cfg = dict( - type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') -# learning policy -param_scheduler = [ - dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True), - dict( - type='LinearLR', - begin=5, - end=min(num_epochs, - int((lr - min_lr) / (lr / num_epochs)) + 5), - end_factor=min_lr / lr, - by_epoch=True), -] - -# dataset settings -icdar2015_textspotting_train = _base_.icdar2015_textspotting_train -icdar2015_textspotting_train.pipeline = _base_.train_pipeline -icdar2015_textspotting_test = _base_.icdar2015_textspotting_test -icdar2015_textspotting_test.pipeline = _base_.test_pipeline - -train_dataloader = dict( - batch_size=4, - num_workers=8, - persistent_workers=True, - sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2), - dataset=icdar2015_textspotting_train) - -val_dataloader = dict( - batch_size=1, - num_workers=4, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=icdar2015_textspotting_test) - -test_dataloader = val_dataloader - -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') diff --git a/projects/SPTS/config/spts/spts_resnet50_150e_pretrain-spts.py b/projects/SPTS/config/spts/spts_resnet50_8xb8-150e_pretrain-spts.py similarity index 68% rename from projects/SPTS/config/spts/spts_resnet50_150e_pretrain-spts.py rename to projects/SPTS/config/spts/spts_resnet50_8xb8-150e_pretrain-spts.py index da9cb52a..42938f43 100644 
--- a/projects/SPTS/config/spts/spts_resnet50_150e_pretrain-spts.py +++ b/projects/SPTS/config/spts/spts_resnet50_8xb8-150e_pretrain-spts.py @@ -2,7 +2,6 @@ _base_ = [ '_base_spts_resnet50.py', '../_base_/datasets/icdar2013-spts.py', '../_base_/datasets/icdar2015-spts.py', - '../_base_/datasets/ctw1500-spts.py', '../_base_/datasets/totaltext-spts.py', '../_base_/datasets/syntext1-spts.py', '../_base_/datasets/syntext2-spts.py', @@ -16,15 +15,12 @@ min_lr = 0.00001 optim_wrapper = dict( type='OptimWrapper', - accumulative_counts=2, optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001), paramwise_cfg=dict(custom_keys={ 'backbone': dict(lr_mult=0.1), })) train_cfg = dict( - type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=20) -val_cfg = dict(type='ValLoop') -test_cfg = dict(type='TestLoop') + type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=5) # learning policy param_scheduler = [ dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True), @@ -39,26 +35,24 @@ param_scheduler = [ # dataset settings train_list = [ - _base_.icdar2013_textspotting_train, _base_.icdar2015_textspotting_train, - _base_.mlt_textspotting_train, _base_.totaltext_textspotting_train, - _base_.syntext1_textspotting_train, _base_.syntext2_textspotting_train, - _base_.ctw1500_textspotting_train + _base_.icdar2013_textspotting_train, + _base_.icdar2015_textspotting_train, + _base_.mlt_textspotting_train, + _base_.totaltext_textspotting_train, + _base_.syntext1_textspotting_train, + _base_.syntext2_textspotting_train, ] train_dataset = dict( type='ConcatDataset', datasets=train_list, pipeline=_base_.train_pipeline) -train_dataloader = dict( - batch_size=4, - num_workers=8, - persistent_workers=True, - sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2), - dataset=train_dataset) - -val_dataloader = None -test_dataloader = None val_evaluator = None test_evaluator = None -val_cfg = None -test_cfg = None +train_dataloader = dict( + batch_size=8, + 
num_workers=8, + pin_memory=True, + persistent_workers=True, + sampler=dict(type='BatchAugSampler', shuffle=True, num_repeats=2), + dataset=train_dataset) diff --git a/projects/SPTS/config/spts/spts_resnet50_8xb8-200e_icdar2013.py b/projects/SPTS/config/spts/spts_resnet50_8xb8-200e_icdar2013.py new file mode 100644 index 00000000..223c1c24 --- /dev/null +++ b/projects/SPTS/config/spts/spts_resnet50_8xb8-200e_icdar2013.py @@ -0,0 +1,87 @@ +_base_ = [ + '_base_spts_resnet50_mmocr.py', + '../_base_/datasets/icdar2013.py', + '../_base_/default_runtime.py', +] + +load_from = 'work_dirs/spts_resnet50_150e_pretrain-spts/epoch_150.pth' + +num_epochs = 200 +lr = 0.00001 + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='generic/hmean', + rule='greater', + _delete_=True), + logger=dict(type='LoggerHook', interval=1)) + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001), + paramwise_cfg=dict(custom_keys={ + 'backbone': dict(lr_mult=0.1), + })) + +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# dataset settings +icdar2013_textspotting_train = _base_.icdar2013_textspotting_train +icdar2013_textspotting_train.pipeline = _base_.train_pipeline +icdar2013_textspotting_test = _base_.icdar2013_textspotting_test +icdar2013_textspotting_test.pipeline = _base_.test_pipeline + +train_dataloader = dict( + batch_size=8, + num_workers=8, + pin_memory=True, + persistent_workers=True, + sampler=dict(type='BatchAugSampler', shuffle=True, num_repeats=2), + dataset=icdar2013_textspotting_train) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + pin_memory=True, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=icdar2013_textspotting_test) + +test_dataloader = val_dataloader + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + 
+val_evaluator = [ + dict( + type='E2EPointMetric', + prefix='generic', + lexicon_path='data/icdar2013/lexicons/GenericVocabulary_new.txt', + pair_path='data/icdar2013/lexicons/' + 'GenericVocabulary_pair_list.txt', + match_dist_thr=None), + dict( + type='E2EPointMetric', + prefix='weak', + lexicon_path='data/icdar2013/lexicons/' + 'ch2_test_vocabulary_new.txt', + pair_path='data/icdar2013/lexicons/' + 'ch2_test_vocabulary_pair_list.txt', + match_dist_thr=0.4), + dict( + type='E2EPointMetric', + prefix='strong', + lexicon_path='data/icdar2013/lexicons/' + 'new_strong_lexicon/lexicons/', + lexicon_mapping=('(.*).jpg', r'new_voc_\1.txt'), + pair_path='data/icdar2013/lexicons/' + 'new_strong_lexicon/pairs/', + pair_mapping=('(.*).jpg', r'pair_voc_\1.txt'), + match_dist_thr=0.4), +] + +test_evaluator = val_evaluator diff --git a/projects/SPTS/config/spts/spts_resnet50_8xb8-200e_icdar2015.py b/projects/SPTS/config/spts/spts_resnet50_8xb8-200e_icdar2015.py new file mode 100644 index 00000000..9f811a80 --- /dev/null +++ b/projects/SPTS/config/spts/spts_resnet50_8xb8-200e_icdar2015.py @@ -0,0 +1,87 @@ +_base_ = [ + '_base_spts_resnet50_mmocr.py', + '../_base_/datasets/icdar2015.py', + '../_base_/default_runtime.py', +] + +load_from = 'work_dirs/spts_resnet50_150e_pretrain-spts/epoch_150.pth' + +num_epochs = 200 +lr = 0.00001 + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='generic/hmean', + rule='greater', + _delete_=True), + logger=dict(type='LoggerHook', interval=10)) + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0001), + paramwise_cfg=dict(custom_keys={ + 'backbone': dict(lr_mult=0.1), + })) + +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# dataset settings +icdar2015_textspotting_train = _base_.icdar2015_textspotting_train +icdar2015_textspotting_train.pipeline = 
_base_.train_pipeline +icdar2015_textspotting_test = _base_.icdar2015_textspotting_test +icdar2015_textspotting_test.pipeline = _base_.test_pipeline + +train_dataloader = dict( + batch_size=8, + num_workers=8, + pin_memory=True, + persistent_workers=True, + sampler=dict(type='BatchAugSampler', shuffle=True, num_repeats=2), + dataset=icdar2015_textspotting_train) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + pin_memory=True, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=icdar2015_textspotting_test) + +test_dataloader = val_dataloader + +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +val_evaluator = [ + dict( + type='E2EPointMetric', + prefix='generic', + lexicon_path='data/icdar2015/lexicons/GenericVocabulary_new.txt', + pair_path='data/icdar2015/lexicons/' + 'GenericVocabulary_pair_list.txt', + match_dist_thr=None), + dict( + type='E2EPointMetric', + prefix='weak', + lexicon_path='data/icdar2015/lexicons/' + 'ch4_test_vocabulary_new.txt', + pair_path='data/icdar2015/lexicons/' + 'ch4_test_vocabulary_pair_list.txt', + match_dist_thr=0.4), + dict( + type='E2EPointMetric', + prefix='strong', + lexicon_path='data/icdar2015/lexicons/' + 'new_strong_lexicon/lexicons/', + lexicon_mapping=('(.*).jpg', r'new_voc_\1.txt'), + pair_path='data/icdar2015/lexicons/' + 'new_strong_lexicon/pairs/', + pair_mapping=('(.*).jpg', r'pair_voc_\1.txt'), + match_dist_thr=0.4), +] + +test_evaluator = val_evaluator diff --git a/projects/SPTS/config/spts/spts_resnet50_350e_totaltext-spts.py b/projects/SPTS/config/spts/spts_resnet50_8xb8-200e_totaltext.py similarity index 54% rename from projects/SPTS/config/spts/spts_resnet50_350e_totaltext-spts.py rename to projects/SPTS/config/spts/spts_resnet50_8xb8-200e_totaltext.py index fa50a302..ed71f000 100644 --- a/projects/SPTS/config/spts/spts_resnet50_350e_totaltext-spts.py +++ b/projects/SPTS/config/spts/spts_resnet50_8xb8-200e_totaltext.py @@ -1,12 +1,21 @@ _base_ = 
[ - '_base_spts_resnet50.py', - '../_base_/datasets/totaltext-spts.py', + '_base_spts_resnet50_mmocr.py', + '../_base_/datasets/totaltext.py', '../_base_/default_runtime.py', ] -num_epochs = 350 +load_from = 'work_dirs/spts_resnet50_150e_pretrain-spts/epoch_150.pth' + +num_epochs = 200 lr = 0.00001 -min_lr = 0.00001 + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='none/hmean', + rule='greater', + _delete_=True), + logger=dict(type='LoggerHook', interval=10)) optim_wrapper = dict( type='OptimWrapper', @@ -19,17 +28,6 @@ train_cfg = dict( type='EpochBasedTrainLoop', max_epochs=num_epochs, val_interval=10) val_cfg = dict(type='ValLoop') test_cfg = dict(type='TestLoop') -# learning policy -param_scheduler = [ - dict(type='LinearLR', end=5, start_factor=1 / 5, by_epoch=True), - dict( - type='LinearLR', - begin=5, - end=min(num_epochs, - int((lr - min_lr) / (lr / num_epochs)) + 5), - end_factor=min_lr / lr, - by_epoch=True), -] # dataset settings totaltext_textspotting_train = _base_.totaltext_textspotting_train @@ -38,16 +36,18 @@ totaltext_textspotting_test = _base_.totaltext_textspotting_test totaltext_textspotting_test.pipeline = _base_.test_pipeline train_dataloader = dict( - batch_size=4, + batch_size=8, num_workers=8, persistent_workers=True, - sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=2), + pin_memory=True, + sampler=dict(type='BatchAugSampler', shuffle=True, num_repeats=2), dataset=totaltext_textspotting_train) val_dataloader = dict( batch_size=1, num_workers=4, persistent_workers=True, + pin_memory=True, sampler=dict(type='DefaultSampler', shuffle=False), dataset=totaltext_textspotting_test) @@ -55,3 +55,21 @@ test_dataloader = val_dataloader val_cfg = dict(type='ValLoop') test_cfg = dict(type='TestLoop') + +val_evaluator = [ + dict( + type='E2EPointMetric', + prefix='none', + word_spotting=True, + match_dist_thr=0.4), + dict( + type='E2EPointMetric', + prefix='full', + 
lexicon_path='data/totaltext/lexicons/weak_voc_new.txt', + pair_path='data/totaltext/lexicons/' + 'weak_voc_pair_list.txt', + word_spotting=True, + match_dist_thr=0.4), +] + +test_evaluator = val_evaluator diff --git a/projects/SPTS/spts/datasets/adel_dataset.py b/projects/SPTS/spts/datasets/adel_dataset.py index 05674afb..d1e3edda 100644 --- a/projects/SPTS/spts/datasets/adel_dataset.py +++ b/projects/SPTS/spts/datasets/adel_dataset.py @@ -40,6 +40,7 @@ class AdelDataset(CocoDataset): None img. The maximum extra number of cycles to get a valid image. Defaults to 1000. """ + METAINFO = {'classes': ('text', )} def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]: """Parse raw annotation to target format. diff --git a/projects/SPTS/spts/metric/e2e_point_metric.py b/projects/SPTS/spts/metric/e2e_point_metric.py index 7a44bbf8..d219b4aa 100644 --- a/projects/SPTS/spts/metric/e2e_point_metric.py +++ b/projects/SPTS/spts/metric/e2e_point_metric.py @@ -1,14 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Sequence +import glob +import os.path as osp +import re +from typing import Dict, List, Optional, Sequence, Tuple import numpy as np import torch from mmengine.evaluator import BaseMetric from mmengine.logging import MMLogger -from shapely.geometry import LineString, Point +from rapidfuzz.distance import Levenshtein +from shapely.geometry import Point from mmocr.registry import METRICS +# TODO: CTW1500 read pair + @METRICS.register_module() class E2EPointMetric(BaseMetric): @@ -17,7 +23,20 @@ class E2EPointMetric(BaseMetric): Args: text_score_thrs (dict): Best text score threshold searching space. Defaults to dict(start=0.8, stop=1, step=0.01). - TODO: docstr + word_spotting (bool): Whether to work in word spotting mode. Defaults + to False. + lexicon_path (str, optional): Lexicon path for word spotting, which + points to a lexicon file or a directory. Defaults to None. 
+ lexicon_mapping (tuple, optional): The rule to map test image name to + its corresponding lexicon file. Only effective when lexicon path + is a directory. Defaults to ('(.*).jpg', r'\1.txt'). + pair_path (str, optional): Pair path for word spotting, which points + to a pair file or a directory. Defaults to None. + pair_mapping (tuple, optional): The rule to map test image name to + its corresponding pair file. Only effective when pair path is a + directory. Defaults to ('(.*).jpg', r'\1.txt'). + match_dist_thr (float, optional): Matching distance threshold for + word spotting. Defaults to None. collect_device (str): Device name used for collecting results from different ranks during distributed training. Must be 'cpu' or 'gpu'. Defaults to 'cpu'. @@ -31,20 +50,52 @@ class E2EPointMetric(BaseMetric): def __init__(self, text_score_thrs: Dict = dict(start=0.8, stop=1, step=0.01), word_spotting: bool = False, + lexicon_path: Optional[str] = None, + lexicon_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'), + pair_path: Optional[str] = None, + pair_mapping: Tuple[str, str] = ('(.*).jpg', r'\1.txt'), + match_dist_thr: Optional[float] = None, collect_device: str = 'cpu', prefix: Optional[str] = None) -> None: super().__init__(collect_device=collect_device, prefix=prefix) self.text_score_thrs = np.arange(**text_score_thrs) self.word_spotting = word_spotting + self.match_dist_thr = match_dist_thr + if lexicon_path: + self.lexicon_mapping = lexicon_mapping + self.pair_mapping = pair_mapping + self.lexicons = self._read_lexicon(lexicon_path) + self.pairs = self._read_pair(pair_path) + + def _read_lexicon(self, lexicon_path: str) -> List[str]: + if lexicon_path.endswith('.txt'): + lexicon = open(lexicon_path, 'r').read().splitlines() + lexicon = [ele.strip() for ele in lexicon] + else: + lexicon = {} + for file in glob.glob(osp.join(lexicon_path, '*.txt')): + basename = osp.basename(file) + lexicon[basename] = self._read_lexicon(file) + return lexicon + + def 
_read_pair(self, pair_path: str) -> Dict[str, str]: + pairs = {} + if pair_path.endswith('.txt'): + pair_lines = open(pair_path, 'r').read().splitlines() + for line in pair_lines: + line = line.strip() + word = line.split(' ')[0].upper() + word_gt = line[len(word) + 1:] + pairs[word] = word_gt + else: + for file in glob.glob(osp.join(pair_path, '*.txt')): + basename = osp.basename(file) + pairs[basename] = self._read_pair(file) + return pairs def poly_center(self, poly_pts): poly_pts = np.array(poly_pts).reshape(-1, 2) - num_points = poly_pts.shape[0] - line1 = LineString(poly_pts[int(num_points / 2):]) - line2 = LineString(poly_pts[:int(num_points / 2)]) - mid_pt1 = np.array(line1.interpolate(0.5, normalized=True).coords[0]) - mid_pt2 = np.array(line2.interpolate(0.5, normalized=True).coords[0]) - return (mid_pt1 + mid_pt2) / 2 + return poly_pts.mean(0) def process(self, data_batch: Sequence[Dict], data_samples: Sequence[Dict]) -> None: @@ -87,6 +138,8 @@ class E2EPointMetric(BaseMetric): ~pred_ignore_flags) result = dict( + # reserved for image-level lexcions + gt_img_name=osp.basename(data_sample.get('img_path', '')), text_scores=text_scores, pred_points=pred_points, gt_points=gt_points, @@ -125,10 +178,33 @@ class E2EPointMetric(BaseMetric): gt_texts = result['gt_texts'] pred_texts = result['pred_texts'] gt_ignore_flags = result['gt_ignore_flags'] + gt_img_name = result['gt_img_name'] + + # Correct the words with lexicon + pred_dist_flags = np.zeros(len(pred_texts), dtype=bool) + if hasattr(self, 'lexicons'): + for i, pred_text in enumerate(pred_texts): + # If it's an image-level lexicon + if isinstance(self.lexicons, dict): + lexicon_name = self._map_img_name( + gt_img_name, self.lexicon_mapping) + pair_name = self._map_img_name(gt_img_name, + self.pair_mapping) + pred_texts[i], match_dist = self._match_word( + pred_text, self.lexicons[lexicon_name], + self.pairs[pair_name]) + else: + pred_texts[i], match_dist = self._match_word( + pred_text, self.lexicons, 
self.pairs) + if (self.match_dist_thr + and match_dist >= self.match_dist_thr): + # won't even count this as a prediction + pred_dist_flags[i] = True # Filter out predictions by IoU threshold for i, text_score_thr in enumerate(self.text_score_thrs): - pred_ignore_flags = text_scores < text_score_thr + pred_ignore_flags = pred_dist_flags | ( + text_scores < text_score_thr) filtered_pred_texts = self._get_true_elements( pred_texts, ~pred_ignore_flags) filtered_pred_points = self._get_true_elements( @@ -148,10 +224,8 @@ class E2EPointMetric(BaseMetric): min_idx = np.argmin(dists) if gt_texts[min_idx] == '###' or gt_ignore_flags[min_idx]: continue - # if not gt_matched[min_idx] and self.text_match( - # gt_texts[min_idx].upper(), pred_text.upper()): - if (not gt_matched[min_idx] and gt_texts[min_idx].upper() - == pred_text.upper()): + if not gt_matched[min_idx] and ( + pred_text.upper() == gt_texts[min_idx].upper()): gt_matched[min_idx] = True num_tp[i] += 1 num_preds[i] += 1 @@ -173,6 +247,69 @@ class E2EPointMetric(BaseMetric): best_eval_results = eval_results return best_eval_results + def _map_img_name(self, img_name: str, mapping: Tuple[str, str]) -> str: + """Map the image name to the another one based on mapping.""" + return re.sub(mapping[0], mapping[1], img_name) + def _true_indexes(self, array: np.ndarray) -> np.ndarray: """Get indexes of True elements from a 1D boolean array.""" return np.where(array)[0] + + def _word_spotting_filter(self, gt_ignore_flags: np.ndarray, + gt_texts: List[str] + ) -> Tuple[np.ndarray, List[str]]: + """Filter out gt instances that cannot be in a valid dictionary, and do + some simple preprocessing to texts.""" + + for i in range(len(gt_texts)): + if gt_ignore_flags[i]: + continue + text = gt_texts[i] + if text[-2:] in ["'s", "'S"]: + text = text[:-2] + text = text.strip('-') + for char in "'!?.:,*\"()·[]/": + text = text.replace(char, ' ') + text = text.strip() + gt_ignore_flags[i] = not self._include_in_dict(text) + if not 
gt_ignore_flags[i]: + gt_texts[i] = text + + return gt_ignore_flags, gt_texts + + def _include_in_dict(self, text: str) -> bool: + """Check if the text could be in a valid dictionary.""" + if len(text) != len(text.replace(' ', '')) or len(text) < 3: + return False + not_allowed = '×÷·' + valid_ranges = [(ord(u'a'), ord(u'z')), (ord(u'A'), ord(u'Z')), + (ord(u'À'), ord(u'ƿ')), (ord(u'DŽ'), ord(u'ɿ')), + (ord(u'Ά'), ord(u'Ͽ')), (ord(u'-'), ord(u'-'))] + for char in text: + code = ord(char) + if (not_allowed.find(char) != -1): + return False + valid = any(code >= r[0] and code <= r[1] for r in valid_ranges) + if not valid: + return False + return True + + def _match_word(self, + text: str, + lexicons: List[str], + pairs: Optional[Dict[str, str]] = None) -> Tuple[str, int]: + """Match the text with the lexicons and pairs.""" + text = text.upper() + matched_word = '' + matched_dist = 100 + for lexicon in lexicons: + lexicon = lexicon.upper() + norm_dist = Levenshtein.distance(text, lexicon) + norm_dist = Levenshtein.normalized_distance(text, lexicon) + if norm_dist < matched_dist: + matched_dist = norm_dist + if pairs: + matched_word = pairs[lexicon] + else: + matched_word = lexicon + return matched_word, matched_dist diff --git a/projects/SPTS/spts/model/spts_decoder.py b/projects/SPTS/spts/model/spts_decoder.py index 90b25c9b..374c65c6 100755 --- a/projects/SPTS/spts/model/spts_decoder.py +++ b/projects/SPTS/spts/model/spts_decoder.py @@ -58,18 +58,15 @@ class SPTSDecoder(BaseDecoder): max_seq_len=self.max_seq_len, init_cfg=init_cfg) self.num_bins = num_bins - self.shifted_seq_end_idx = self.num_bins + self.dictionary.seq_end_idx - self.shifted_start_idx = self.num_bins + self.dictionary.start_idx - actual_num_classes = self.dictionary.num_classes + num_bins - - self.embedding = DecoderEmbeddings( - actual_num_classes, self.dictionary.padding_idx + num_bins, - d_model, self.max_seq_len, dropout) + self.embedding = DecoderEmbeddings(self.dictionary.num_classes, + 
self.dictionary.padding_idx, + d_model, self.max_seq_len, dropout) self.pos_embedding = PositionEmbeddingSine(d_model // 2) self.vocab_embed = self._gen_vocab_embed(d_model, d_model, - actual_num_classes, 3) + self.dictionary.num_classes, + 3) encoder_layer = TransformerEncoderLayer(d_model, n_head, d_feedforward, dropout, 'relu', normalize_before) @@ -166,7 +163,7 @@ class SPTSDecoder(BaseDecoder): max_probs = [] seq = torch.zeros( batch_size, 1, dtype=torch.long).to( - out_enc.device) + self.shifted_start_idx + out_enc.device) + self.dictionary.start_idx for i in range(self.max_seq_len): tgt = self.embedding(seq).permute(1, 0, 2) hs = self.decoder( @@ -182,13 +179,13 @@ class SPTSDecoder(BaseDecoder): # bins chars unk eos seq_eos sos padding if i % 27 == 0: # coordinate or eos - out[:, self.num_bins:self.shifted_seq_end_idx] = 0 - out[:, self.shifted_seq_end_idx + 1:] = 0 + out[:, self.num_bins:self.dictionary.seq_end_idx] = 0 + out[:, self.dictionary.seq_end_idx + 1:] = 0 elif i % 27 == 1: # coordinate out[:, self.num_bins:] = 0 else: # chars out[:, :self.num_bins] = 0 - out[:, self.shifted_seq_end_idx:] = 0 + out[:, self.dictionary.seq_end_idx:] = 0 max_prob, extra_seq = torch.max(out, dim=-1, keepdim=True) # prob, extra_seq = out.topk(dim=-1, k=1) @@ -196,7 +193,7 @@ class SPTSDecoder(BaseDecoder): # TODO: optimize for multi-batch seq = torch.cat([seq, extra_seq], dim=-1) max_probs.append(max_prob) - if extra_seq[0] == self.shifted_seq_end_idx: + if extra_seq[0] == self.dictionary.seq_end_idx: break max_probs = torch.cat(max_probs, dim=-1) diff --git a/projects/SPTS/spts/model/spts_dictionary.py b/projects/SPTS/spts/model/spts_dictionary.py index fb88f166..f204a71a 100644 --- a/projects/SPTS/spts/model/spts_dictionary.py +++ b/projects/SPTS/spts/model/spts_dictionary.py @@ -15,6 +15,8 @@ class SPTSDictionary(Dictionary): Args: dict_file (str): The path of Character dict file which a single character must occupies a line. 
+ num_bins (int): Number of bins dividing the image, which is used to + shift the character indexes. Defaults to 1000. with_start (bool): The flag to control whether to include the start token. Defaults to False. with_end (bool): The flag to control whether to include the end token. @@ -45,6 +47,7 @@ class SPTSDictionary(Dictionary): def __init__( self, dict_file: str, + num_bins: int = 1000, with_start: bool = False, with_end: bool = False, with_seq_end: bool = False, @@ -74,6 +77,26 @@ class SPTSDictionary(Dictionary): padding_token=padding_token, unknown_token=unknown_token) + self.num_bins = num_bins + self._shift_idx() + + @property + def num_classes(self) -> int: + """int: Number of output classes. Special tokens are counted. + """ + return len(self._dict) + self.num_bins + + def _shift_idx(self): + idx_terms = [ + 'start_idx', 'end_idx', 'unknown_idx', 'seq_end_idx', 'padding_idx' + ] + for term in idx_terms: + value = getattr(self, term) + if value: + setattr(self, term, value + self.num_bins) + for char in self._dict: + self._char2idx[char] += self.num_bins + def _update_dict(self): """Update the dict with tokens according to parameters.""" # BOS/EOS @@ -129,10 +152,11 @@ class SPTSDictionary(Dictionary): assert isinstance(index, (list, tuple)) string = '' for i in index: - assert i < len(self._dict), f'Index: {i} out of range! Index ' \ - f'must be less than {len(self._dict)}' + assert i < self.num_classes, f'Index: {i} out of range! 
Index ' \ + f'must be less than {self.num_classes}' # TODO: find its difference from ignore_chars # in TextRecogPostprocessor - if self._dict[i] is not None: - string += self._dict[i] + shifted_i = i - self.num_bins + if self._dict[shifted_i] is not None: + string += self._dict[shifted_i] return string diff --git a/projects/SPTS/spts/model/spts_module_loss.py b/projects/SPTS/spts/model/spts_module_loss.py index 370cd043..2847e14c 100644 --- a/projects/SPTS/spts/model/spts_module_loss.py +++ b/projects/SPTS/spts/model/spts_module_loss.py @@ -81,7 +81,7 @@ class SPTSModuleLoss(CEModuleLoss): self.max_num_text = (self.max_seq_len - 1) // (2 + max_text_len) self.num_bins = num_bins - weights = torch.ones(self.dictionary.num_classes + num_bins) + weights = torch.ones(self.dictionary.num_classes, dtype=torch.float32) weights[self.dictionary.seq_end_idx] = seq_eos_coef weights.requires_grad_ = False self.loss_ce = nn.CrossEntropyLoss( @@ -117,18 +117,24 @@ class SPTSModuleLoss(CEModuleLoss): if data_sample.get('have_target', False): continue - if len(data_sample.gt_instances.polygons) > self.max_num_text: + if len(data_sample.gt_instances) > self.max_num_text: keep = random.sample( - range(len(data_sample.gt_instances['polygons'])), - self.max_num_text) + range(len(data_sample.gt_instances)), self.max_num_text) data_sample.gt_instances = data_sample.gt_instances[keep] gt_instances = data_sample.gt_instances - if len(gt_instances.polygons) > 0: + if len(gt_instances) > 0: center_pts = [] # Slightly different from the original implementation # which gets the center points from bezier curves + # for bezier_pt in gt_instances.beziers: + # bezier_pt = bezier_pt.reshape(8, 2) + # mid_pt1 = sample_bezier_curve( + # bezier_pt[:4], mid_point=True) + # mid_pt2 = sample_bezier_curve( + # bezier_pt[4:], mid_point=True) + # center_pt = (mid_pt1 + mid_pt2) / 2 for polygon in gt_instances.polygons: center_pt = polygon.reshape(-1, 2).mean(0) center_pts.append(center_pt) @@ -152,7 +158,7 
@@ class SPTSModuleLoss(CEModuleLoss): dtype=torch.long) + self.dictionary.end_idx max_len = min(self.max_text_len - 1, len(indexes)) indexes_tensor[:max_len] = torch.LongTensor(indexes)[:max_len] - indexes_tensor = indexes_tensor + self.num_bins + indexes_tensor = indexes_tensor gt_indexes.append(indexes_tensor) if len(gt_indexes) == 0: @@ -164,15 +170,12 @@ class SPTSModuleLoss(CEModuleLoss): if self.dictionary.start_idx is not None: gt_indexes = torch.cat([ - torch.LongTensor( - [self.dictionary.start_idx + self.num_bins]), - gt_indexes + torch.LongTensor([self.dictionary.start_idx]), gt_indexes ]) if self.dictionary.seq_end_idx is not None: gt_indexes = torch.cat([ gt_indexes, - torch.LongTensor( - [self.dictionary.seq_end_idx + self.num_bins]) + torch.LongTensor([self.dictionary.seq_end_idx]) ]) batch_max_len = max(batch_max_len, len(gt_indexes)) @@ -190,7 +193,7 @@ class SPTSModuleLoss(CEModuleLoss): padded_indexes = ( torch.zeros(batch_max_len, dtype=torch.long) + - self.dictionary.padding_idx + self.num_bins) + self.dictionary.padding_idx) padded_indexes[:len(indexes)] = indexes data_sample.gt_instances.set_metainfo( dict(padded_indexes=padded_indexes)) diff --git a/projects/SPTS/spts/model/spts_postprocessor.py b/projects/SPTS/spts/model/spts_postprocessor.py index 6994aafa..249c9694 100644 --- a/projects/SPTS/spts/model/spts_postprocessor.py +++ b/projects/SPTS/spts/model/spts_postprocessor.py @@ -102,12 +102,11 @@ class SPTSPostprocessor(BaseTextRecogPostprocessor): for char_index, char_score in zip(output_index[2:], output_score[2:]): # the first num_bins indexes are for points - dict_idx = char_index - self.num_bins - if dict_idx in self.ignore_indexes: + if char_index in self.ignore_indexes: continue - if dict_idx == self.dictionary.end_idx: + if char_index == self.dictionary.end_idx: break - indexes[-1].append(dict_idx) + indexes[-1].append(char_index) char_scores.append(char_score) text_scores.append(np.mean(char_scores).item()) return indexes, 
text_scores, points, pt_scores