diff --git a/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py b/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py index 790355fc..29fa5d6d 100644 --- a/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py +++ b/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py @@ -74,6 +74,8 @@ test_pipeline = [ data = dict( samples_per_gpu=16, workers_per_gpu=8, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=data_root + '/instances_training.json', diff --git a/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py b/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py index f1ccec51..f3250e79 100644 --- a/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py +++ b/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py @@ -83,6 +83,8 @@ test_pipeline = [ data = dict( samples_per_gpu=8, workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=data_root + '/instances_training.json', diff --git a/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py b/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py index f59f049c..378d37e1 100644 --- a/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py +++ b/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py @@ -91,6 +91,8 @@ test_pipeline = [ data = dict( samples_per_gpu=4, workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=f'{data_root}/instances_training.json', diff --git a/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py b/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py index 2b4fd6d1..03ccd0ca 100644 --- a/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py +++ b/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py @@ -96,6 +96,8 @@ test_pipeline = [ data = dict( samples_per_gpu=8, workers_per_gpu=2, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=data_root + '/instances_training.json', diff --git a/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py b/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py index 5c5e048d..3cdc7a72 100644 --- a/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py +++ b/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py @@ -95,6 +95,8 @@ test_pipeline = [ data = dict( samples_per_gpu=6, workers_per_gpu=2, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=data_root + '/instances_training.json', diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py index d821bb7f..d8841c90 100644 --- a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py +++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py @@ -47,6 +47,8 @@ test_pipeline = [ data = dict( samples_per_gpu=8, workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=data_root + '/instances_training.json', diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py index e2d8be68..52fce0af 100644 --- a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py +++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py @@ -46,6 +46,8 @@ test_pipeline = [ data = dict( samples_per_gpu=8, workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=data_root + '/instances_training.json', diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py index 8b948e09..8fdb3f1c 100644 --- a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py +++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py @@ -47,6 +47,8 @@ test_pipeline = [ data = dict( samples_per_gpu=8, workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=data_root + '/instances_training.json', diff --git a/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py b/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py index 58d1d22b..a1bfb629 100644 --- a/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py +++ b/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py @@ -82,6 +82,8 @@ test_pipeline = [ data = dict( samples_per_gpu=2, workers_per_gpu=2, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=data_root + '/instances_training.json', diff --git a/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py b/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py index de12c26c..e5681c6a 100644 --- a/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py +++ b/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py @@ -80,6 +80,8 @@ test_pipeline = [ data = dict( samples_per_gpu=8, workers_per_gpu=2, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=data_root + '/instances_training.json', diff --git a/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py b/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py index 933f0bae..4b41851f 100644 --- a/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py +++ b/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py @@ -75,6 +75,8 @@ test_pipeline = [ data = dict( samples_per_gpu=4, workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=data_root + '/instances_training.json', diff --git a/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py b/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py index e1f5fd04..8be704ce 100644 --- a/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py +++ b/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py @@ -89,6 +89,8 @@ test_pipeline = [ data = dict( samples_per_gpu=2, workers_per_gpu=2, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=data_root + '/instances_training.json', diff --git a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py index 5eb7538c..d20ec29e 100644 --- a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py +++ b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py @@ -89,6 +89,8 @@ test_pipeline = [ data = dict( samples_per_gpu=8, workers_per_gpu=2, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=data_root + '/instances_training.json', diff --git a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py index 3c20c4cc..19c5aea0 100644 --- a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py +++ b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py @@ -84,6 +84,8 @@ test_pipeline = [ data = dict( samples_per_gpu=8, workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=data_root + '/instances_training.json', diff --git a/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py b/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py index e710835d..b7dba589 100644 --- a/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py +++ b/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py @@ -94,6 +94,8 @@ test_pipeline = [ data = dict( samples_per_gpu=4, workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type=dataset_type, ann_file=f'{data_root}/instances_training.json', diff --git a/configs/textrecog/crnn/crnn_academic_dataset.py b/configs/textrecog/crnn/crnn_academic_dataset.py index 748625e4..1875f69f 100644 --- a/configs/textrecog/crnn/crnn_academic_dataset.py +++ b/configs/textrecog/crnn/crnn_academic_dataset.py @@ -148,6 +148,8 @@ test6['ann_file'] = test_ann_file6 data = dict( samples_per_gpu=64, workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict(type='ConcatDataset', datasets=[train1]), val=dict( type='ConcatDataset', diff --git a/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py b/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py index 9299583b..92398b93 100644 --- a/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py +++ b/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py @@ -152,6 +152,8 @@ test6['ann_file'] = test_ann_file6 data = dict( samples_per_gpu=128, workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict(type='ConcatDataset', datasets=[train1, train2]), val=dict( type='ConcatDataset', diff --git a/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py b/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py index b003b823..121a9f56 100644 --- a/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py +++ b/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py @@ -152,6 +152,8 @@ test6['ann_file'] = test_ann_file6 data = dict( samples_per_gpu=128, workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict(type='ConcatDataset', datasets=[train1, train2]), val=dict( type='ConcatDataset', diff --git a/configs/textrecog/robust_scanner/robustscanner_r31_academic.py b/configs/textrecog/robust_scanner/robustscanner_r31_academic.py index e90dd75d..7c1b8636 100644 --- a/configs/textrecog/robust_scanner/robustscanner_r31_academic.py +++ b/configs/textrecog/robust_scanner/robustscanner_r31_academic.py @@ -182,6 +182,8 @@ test6['ann_file'] = test_ann_file6 data = dict( samples_per_gpu=64, workers_per_gpu=2, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type='ConcatDataset', datasets=[ diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py index 4e405227..f00ab0cc 100644 --- a/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py +++ b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py @@ -204,6 +204,8 @@ test6['ann_file'] = test_ann_file6 data = dict( samples_per_gpu=64, workers_per_gpu=2, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type='ConcatDataset', datasets=[ diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_chinese.py b/configs/textrecog/sar/sar_r31_parallel_decoder_chinese.py index 90e7a0de..5e58bbd7 100644 --- a/configs/textrecog/sar/sar_r31_parallel_decoder_chinese.py +++ b/configs/textrecog/sar/sar_r31_parallel_decoder_chinese.py @@ -119,6 +119,8 @@ test = dict( data = dict( samples_per_gpu=40, workers_per_gpu=2, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict(type='ConcatDataset', datasets=[train]), val=dict(type='ConcatDataset', datasets=[test]), test=dict(type='ConcatDataset', datasets=[test])) diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py b/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py index 70c08464..75cd15e7 100755 --- a/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py +++ b/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py @@ -30,24 +30,19 @@ train_pipeline = [ test_pipeline = [ dict(type='LoadImageFromFile'), dict( - type='MultiRotateAugOCR', - rotate_degrees=[0, 90, 270], - transforms=[ - dict( - type='ResizeOCR', - height=48, - min_width=48, - max_width=160, - keep_aspect_ratio=True), - dict(type='ToTensorOCR'), - dict(type='NormalizeOCR', **img_norm_cfg), - dict( - type='Collect', - keys=['img'], - meta_keys=[ - 'filename', 'ori_shape', 'img_shape', 'valid_ratio', - 'img_norm_cfg', 'ori_filename' - ]), + type='ResizeOCR', + height=48, + min_width=48, + max_width=160, + keep_aspect_ratio=True), + dict(type='ToTensorOCR'), + dict(type='NormalizeOCR', **img_norm_cfg), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'filename', 'ori_shape', 'img_shape', 'valid_ratio', + 'img_norm_cfg', 'ori_filename' ]) ] @@ -66,7 +61,7 @@ train1 = dict( keys=['filename', 'text'], keys_idx=[0, 1], separator=' ')), - pipeline=train_pipeline, + pipeline=None, test_mode=False) train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb' @@ -82,7 +77,7 @@ train2 = dict( keys=['filename', 'text'], keys_idx=[0, 1], separator=' ')), - pipeline=train_pipeline, + pipeline=None, test_mode=False) test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb' @@ -92,20 +87,25 @@ test = dict( ann_file=test_anno_file1, loader=dict( type='LmdbLoader', - repeat=1, + repeat=10, parser=dict( type='LineStrParser', keys=['filename', 'text'], keys_idx=[0, 1], separator=' ')), - pipeline=test_pipeline, + pipeline=None, test_mode=True) data = dict( - samples_per_gpu=16, workers_per_gpu=2, - train=dict(type='ConcatDataset', datasets=[train1, train2]), - val=dict(type='ConcatDataset', datasets=[test]), - test=dict(type='ConcatDataset', datasets=[test])) + samples_per_gpu=8, + train=dict( + type='UniformConcatDataset', + datasets=[train1, train2], + pipeline=train_pipeline), + val=dict( + type='UniformConcatDataset', datasets=[test], pipeline=test_pipeline), + test=dict( + type='UniformConcatDataset', datasets=[test], pipeline=test_pipeline)) evaluation = dict(interval=1, metric='acc') diff --git a/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py index 6fa00dd7..5c468625 100644 --- a/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py +++ b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py @@ -204,6 +204,8 @@ test6['ann_file'] = test_ann_file6 data = dict( samples_per_gpu=64, workers_per_gpu=2, + val_dataloader=dict(samples_per_gpu=1), + test_dataloader=dict(samples_per_gpu=1), train=dict( type='ConcatDataset', datasets=[ diff --git a/docs/getting_started.md b/docs/getting_started.md index 40578e1d..fbab8553 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -202,7 +202,12 @@ To support the tasks of `text detection`, `text recognition` and `key informatio Here we show some examples of using different combination of `loader` and `parser`. -#### Encoder-Decoder-Based Text Recognition Task +#### Text Recognition Task + +##### OCRDataset + +*Dataset for encoder-decoder based recognizer* + ```python dataset_type = 'OCRDataset' img_prefix = 'tests/data/ocr_toy_dataset/imgs' @@ -225,7 +230,7 @@ train = dict( You can check the content of the annotation file in `tests/data/ocr_toy_dataset/label.txt`. The combination of `HardDiskLoader` and `LineStrParser` will return a dict for each file by calling `__getitem__`: `{'filename': '1223731.jpg', 'text': 'GRAND'}`. -##### Optional Arguments: +**Optional Arguments:** - `repeat`: The number of repeated lines in the annotation files. For example, if there are `10` lines in the annotation file, setting `repeat=10` will generate a corresponding annotation file with size `100`. @@ -254,7 +259,10 @@ train = dict( test_mode=False) ``` -#### Segmentation-Based Text Recognition Task +##### OCRSegDataset + +*Dataset for segmentation-based recognizer* + ```python prefix = 'tests/data/ocr_char_ann_toy_dataset/' train = dict( @@ -277,6 +285,11 @@ The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for ``` #### Text Detection Task + +##### TextDetDataset + +*Dataset with annotation file in line-json txt format* + ```python dataset_type = 'TextDetDataset' img_prefix = 'tests/data/toy_dataset/imgs' @@ -302,7 +315,10 @@ The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for ``` -### COCO-like Dataset +##### IcdarDataset + +*Dataset with annotation file in coco-like json format* + For text detection, you can also use an annotation file in a COCO format that is defined in [mmdet](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/coco.py): ```python dataset_type = 'IcdarDataset' @@ -325,4 +341,40 @@ You can check the content of the annotation file in `tests/data/toy_dataset/inst ```shell python tools/data_converter/ctw1500_converter.py ${src_root_path} -o ${out_path} --split-list training test ``` + +#### UniformConcatDataset + +To use the `universal pipeline` for multiple datasets, we design `UniformConcatDataset`. +For example, apply `train_pipeline` for both `train1` and `train2`, + +```python +data = dict( + ... + train=dict( + type='UniformConcatDataset', + datasets=[train1, train2], + pipeline=train_pipeline)) ``` + +Meanwhile, we have +- train_dataloader +- val_dataloader +- test_dataloader + +to give specific settings. They will override the general settings in `data` dict. +For example, + +```python +data = dict( + workers_per_gpu=2, # global setting + train_dataloader=dict(samples_per_gpu=8, drop_last=True), # train-specific setting + val_dataloader=dict(samples_per_gpu=8, workers_per_gpu=1), # val-specific setting + test_dataloader=dict(samples_per_gpu=8), # test-specific setting + ... +``` +`workers_per_gpu` is global setting and `train_dataloader` and `val_dataloader` will inherit the values. +`val_dataloader` override the value by `workers_per_gpu=1`. + +To activate `batch inference` for `val` and `test`, please set `val_dataloader=dict(samples_per_gpu=8)` and `test_dataloader=dict(samples_per_gpu=8)` as above. +Or just set `samples_per_gpu=8` as global setting. +See [config](/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py) for an example. diff --git a/mmocr/apis/inference.py b/mmocr/apis/inference.py index e9e5b777..1c009a8b 100644 --- a/mmocr/apis/inference.py +++ b/mmocr/apis/inference.py @@ -6,23 +6,40 @@ from mmdet.datasets import replace_ImageToTensor from mmdet.datasets.pipelines import Compose -def disable_text_recog_aug_test(cfg): +def disable_text_recog_aug_test(cfg, set_types=None): """Remove aug_test from test pipeline of text recognition. Args: cfg (mmcv.Config): Input config. + set_types (list[str]): Type of dataset source. Should be + None or sublist of ['test', 'val'] Returns: cfg (mmcv.Config): Output config removing `MultiRotateAugOCR` in test pipeline. """ - if cfg.data.test.pipeline[1].type == 'MultiRotateAugOCR': - cfg.data.test.pipeline = [ - cfg.data.test.pipeline[0], *cfg.data.test.pipeline[1].transforms - ] + assert set_types is None or isinstance(set_types, list) + if set_types is None: + set_types = ['val', 'test'] + for set_type in set_types: + if cfg.data[set_type].pipeline[1].type == 'MultiRotateAugOCR': + cfg.data[set_type].pipeline = [ + cfg.data[set_type].pipeline[0], + *cfg.data[set_type].pipeline[1].transforms + ] + assert_if_not_support_batch_mode(cfg, set_type) return cfg +def assert_if_not_support_batch_mode(cfg, set_type='test'): + if cfg.data[set_type].pipeline[1].type == 'ResizeOCR': + if cfg.data[set_type].pipeline[1].max_width is None: + raise Exception('Batch mode is not supported ' + 'since the image width is not fixed, ' + 'in the case that keeping aspect ratio but ' + 'max_width is none when do resize.') + + def model_inference(model, imgs, batch_mode=False): """Inference image(s) with the detector. @@ -51,13 +68,7 @@ def model_inference(model, imgs, batch_mode=False): cfg = model.cfg if batch_mode: - if cfg.data.test.pipeline[1].type == 'ResizeOCR': - if cfg.data.test.pipeline[1].max_width is None: - raise Exception('Free resize do not support batch mode ' - 'since the image width is not fixed, ' - 'for resize keeping aspect ratio and ' - 'max_width is not give.') - cfg = disable_text_recog_aug_test(cfg) + cfg = disable_text_recog_aug_test(cfg, set_types=['test']) device = next(model.parameters()).device # model device diff --git a/mmocr/apis/train.py b/mmocr/apis/train.py index e3bb9cc4..ecebeac4 100644 --- a/mmocr/apis/train.py +++ b/mmocr/apis/train.py @@ -10,6 +10,7 @@ from mmdet.core import DistEvalHook, EvalHook from mmdet.datasets import (build_dataloader, build_dataset, replace_ImageToTensor) +from mmocr.apis.inference import disable_text_recog_aug_test from mmocr.utils import get_root_logger @@ -24,30 +25,32 @@ def train_detector(model, # prepare data loaders dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] - if 'imgs_per_gpu' in cfg.data: - logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. ' - 'Please use "samples_per_gpu" instead') - if 'samples_per_gpu' in cfg.data: - logger.warning( - f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and ' - f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"' - f'={cfg.data.imgs_per_gpu} is used in this experiments') - else: - logger.warning( - 'Automatically set "samples_per_gpu"="imgs_per_gpu"=' - f'{cfg.data.imgs_per_gpu} in this experiments') - cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu - - data_loaders = [ - build_dataloader( - ds, - cfg.data.samples_per_gpu, - cfg.data.workers_per_gpu, - # cfg.gpus will be ignored if distributed - len(cfg.gpu_ids), + # step 1: give default values and override (if exist) from cfg.data + loader_cfg = { + **dict( + seed=cfg.get('seed'), + drop_last=False, dist=distributed, - seed=cfg.seed) for ds in dataset - ] + num_gpus=len(cfg.gpu_ids)), + **({} if torch.__version__ != 'parrots' else dict( + prefetch_num=2, + pin_memory=False, + )), + **dict((k, cfg.data[k]) for k in [ + 'samples_per_gpu', + 'workers_per_gpu', + 'shuffle', + 'seed', + 'drop_last', + 'prefetch_num', + 'pin_memory', + ] if k in cfg.data) + } + + # step 2: cfg.data.train_dataloader has highest priority + train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {})) + + data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset] # put model on gpus if distributed: @@ -110,19 +113,28 @@ def train_detector(model, # register eval hooks if validate: - # Support batch_size > 1 in validation - val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1) + val_samples_per_gpu = (cfg.data.get('val_dataloader', {})).get( + 'samples_per_gpu', cfg.data.get('samples_per_gpu', 1)) if val_samples_per_gpu > 1: - # Replace 'ImageToTensor' to 'DefaultFormatBundle' - cfg.data.val.pipeline = replace_ImageToTensor( - cfg.data.val.pipeline) + # Support batch_size > 1 in test for text recognition + # by disable MultiRotateAugOCR since it is useless for most case + cfg = disable_text_recog_aug_test(cfg) + if cfg.data.val.get('pipeline', None) is not None: + # Replace 'ImageToTensor' to 'DefaultFormatBundle' + cfg.data.val.pipeline = replace_ImageToTensor( + cfg.data.val.pipeline) + val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) - val_dataloader = build_dataloader( - val_dataset, - samples_per_gpu=val_samples_per_gpu, - workers_per_gpu=cfg.data.workers_per_gpu, - dist=distributed, - shuffle=False) + + val_loader_cfg = { + **loader_cfg, + **dict(shuffle=False, drop_last=False), + **cfg.data.get('val_dataloader', {}), + **dict(samples_per_gpu=val_samples_per_gpu) + } + + val_dataloader = build_dataloader(val_dataset, **val_loader_cfg) + eval_cfg = cfg.get('evaluation', {}) eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' eval_hook = DistEvalHook if distributed else EvalHook diff --git a/mmocr/datasets/__init__.py b/mmocr/datasets/__init__.py index b4bad7b4..886d098e 100644 --- a/mmocr/datasets/__init__.py +++ b/mmocr/datasets/__init__.py @@ -9,6 +9,7 @@ from .ocr_dataset import OCRDataset from .ocr_seg_dataset import OCRSegDataset from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets from .text_det_dataset import TextDetDataset +from .uniform_concat_dataset import UniformConcatDataset from .utils import * # NOQA @@ -16,7 +17,7 @@ __all__ = [ 'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset', 'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle', 'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets', - 'NerDataset' + 'NerDataset', 'UniformConcatDataset' ] __all__ += utils.__all__ diff --git a/mmocr/datasets/uniform_concat_dataset.py b/mmocr/datasets/uniform_concat_dataset.py new file mode 100644 index 00000000..0fe7275c --- /dev/null +++ b/mmocr/datasets/uniform_concat_dataset.py @@ -0,0 +1,27 @@ +import copy + +from mmdet.datasets import DATASETS, ConcatDataset, build_dataset + + +@DATASETS.register_module() +class UniformConcatDataset(ConcatDataset): + """A wrapper of concatenated dataset. + + Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but + concat the group flag for image aspect ratio. + + Args: + datasets (list[:obj:`Dataset`]): A list of datasets. + separate_eval (bool): Whether to evaluate the results + separately if it is used as validation dataset. + Defaults to True. + """ + + def __init__(self, datasets, separate_eval=True, pipeline=None, **kwargs): + from_cfg = all(isinstance(x, dict) for x in datasets) + if pipeline is not None: + assert from_cfg, 'datasets should be config dicts' + for dataset in datasets: + dataset['pipeline'] = copy.deepcopy(pipeline) + datasets = [build_dataset(c, kwargs) for c in datasets] + super().__init__(datasets, separate_eval) diff --git a/tests/test_apis/test_model_inference.py b/tests/test_apis/test_model_inference.py index 63036f8d..e7c164aa 100644 --- a/tests/test_apis/test_model_inference.py +++ b/tests/test_apis/test_model_inference.py @@ -114,19 +114,19 @@ def test_model_batch_inference_raises_exception_error_free_resize_recog( with pytest.raises( Exception, - match='Free resize do not support batch mode ' + match='Batch mode is not supported ' 'since the image width is not fixed, ' - 'for resize keeping aspect ratio and ' - 'max_width is not give.'): + 'in the case that keeping aspect ratio but ' + 'max_width is none when do resize.'): sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg') model_inference( model, [sample_img_path, sample_img_path], batch_mode=True) with pytest.raises( Exception, - match='Free resize do not support batch mode ' + match='Batch mode is not supported ' 'since the image width is not fixed, ' - 'for resize keeping aspect ratio and ' - 'max_width is not give.'): + 'in the case that keeping aspect ratio but ' + 'max_width is none when do resize.'): img = imread(sample_img_path) model_inference(model, [img, img], batch_mode=True) diff --git a/tools/test.py b/tools/test.py index 63593168..fd352460 100755 --- a/tools/test.py +++ b/tools/test.py @@ -13,6 +13,7 @@ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, from mmdet.apis import multi_gpu_test, single_gpu_test from mmdet.datasets import replace_ImageToTensor +from mmocr.apis.inference import disable_text_recog_aug_test from mmocr.datasets import build_dataloader, build_dataset from mmocr.models import build_detector @@ -140,12 +141,16 @@ def main(): # in case the test dataset is concatenated samples_per_gpu = 1 if isinstance(cfg.data.test, dict): - cfg.data.test.test_mode = True - samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1) + samples_per_gpu = (cfg.data.get('test_dataloader', {})).get( + 'samples_per_gpu', cfg.data.get('samples_per_gpu', 1)) if samples_per_gpu > 1: - # Replace 'ImageToTensor' to 'DefaultFormatBundle' - cfg.data.test.pipeline = replace_ImageToTensor( - cfg.data.test.pipeline) + # Support batch_size > 1 in test for text recognition + # by disable MultiRotateAugOCR since it is useless for most case + cfg = disable_text_recog_aug_test(cfg) + if cfg.data.test.get('pipeline', None) is not None: + # Replace 'ImageToTensor' to 'DefaultFormatBundle' + cfg.data.test.pipeline = replace_ImageToTensor( + cfg.data.test.pipeline) elif isinstance(cfg.data.test, list): for ds_cfg in cfg.data.test: ds_cfg.test_mode = True @@ -163,13 +168,29 @@ def main(): init_dist(args.launcher, **cfg.dist_params) # build the dataloader - dataset = build_dataset(cfg.data.test) - data_loader = build_dataloader( - dataset, - samples_per_gpu=samples_per_gpu, - workers_per_gpu=cfg.data.workers_per_gpu, - dist=distributed, - shuffle=False) + dataset = build_dataset(cfg.data.test, dict(test_mode=True)) + # step 1: give default values and override (if exist) from cfg.data + loader_cfg = { + **dict(seed=cfg.get('seed'), drop_last=False, dist=distributed), + **({} if torch.__version__ != 'parrots' else dict( + prefetch_num=2, + pin_memory=False, + )), + **dict((k, cfg.data[k]) for k in [ + 'workers_per_gpu', + 'seed', + 'prefetch_num', + 'pin_memory', + ] if k in cfg.data) + } + test_loader_cfg = { + **loader_cfg, + **dict(shuffle=False, drop_last=False), + **cfg.data.get('test_dataloader', {}), + **dict(samples_per_gpu=samples_per_gpu) + } + + data_loader = build_dataloader(dataset, **test_loader_cfg) # build the model and load checkpoint cfg.model.train_cfg = None