diff --git a/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py b/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py
index 790355fc..29fa5d6d 100644
--- a/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py
+++ b/configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py
@@ -74,6 +74,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=16,
workers_per_gpu=8,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
diff --git a/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py b/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py
index f1ccec51..f3250e79 100644
--- a/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py
+++ b/configs/textdet/dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py
@@ -83,6 +83,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
diff --git a/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py b/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py
index f59f049c..378d37e1 100644
--- a/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py
+++ b/configs/textdet/drrg/drrg_r50_fpn_unet_1200e_ctw1500.py
@@ -91,6 +91,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=f'{data_root}/instances_training.json',
diff --git a/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py b/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py
index 2b4fd6d1..03ccd0ca 100644
--- a/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py
+++ b/configs/textdet/fcenet/fcenet_r50_fpn_1500e_icdar2015.py
@@ -96,6 +96,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
diff --git a/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py b/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py
index 5c5e048d..3cdc7a72 100644
--- a/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py
+++ b/configs/textdet/fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py
@@ -95,6 +95,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=6,
workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py
index d821bb7f..d8841c90 100644
--- a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py
+++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py
@@ -47,6 +47,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py
index e2d8be68..52fce0af 100644
--- a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py
+++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py
@@ -46,6 +46,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
diff --git a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py
index 8b948e09..8fdb3f1c 100644
--- a/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py
+++ b/configs/textdet/maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py
@@ -47,6 +47,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
diff --git a/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py b/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py
index 58d1d22b..a1bfb629 100644
--- a/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py
+++ b/configs/textdet/panet/panet_r18_fpem_ffm_600e_ctw1500.py
@@ -82,6 +82,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
diff --git a/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py b/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py
index de12c26c..e5681c6a 100644
--- a/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py
+++ b/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py
@@ -80,6 +80,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
diff --git a/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py b/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py
index 933f0bae..4b41851f 100644
--- a/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py
+++ b/configs/textdet/panet/panet_r50_fpem_ffm_600e_icdar2017.py
@@ -75,6 +75,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
diff --git a/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py b/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py
index e1f5fd04..8be704ce 100644
--- a/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py
+++ b/configs/textdet/psenet/psenet_r50_fpnf_600e_ctw1500.py
@@ -89,6 +89,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
diff --git a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py
index 5eb7538c..d20ec29e 100644
--- a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py
+++ b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py
@@ -89,6 +89,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
diff --git a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py
index 3c20c4cc..19c5aea0 100644
--- a/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py
+++ b/configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py
@@ -84,6 +84,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
diff --git a/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py b/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py
index e710835d..b7dba589 100644
--- a/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py
+++ b/configs/textdet/textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py
@@ -94,6 +94,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=f'{data_root}/instances_training.json',
diff --git a/configs/textrecog/crnn/crnn_academic_dataset.py b/configs/textrecog/crnn/crnn_academic_dataset.py
index 748625e4..1875f69f 100644
--- a/configs/textrecog/crnn/crnn_academic_dataset.py
+++ b/configs/textrecog/crnn/crnn_academic_dataset.py
@@ -148,6 +148,8 @@ test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(type='ConcatDataset', datasets=[train1]),
val=dict(
type='ConcatDataset',
diff --git a/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py b/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py
index 9299583b..92398b93 100644
--- a/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py
+++ b/configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py
@@ -152,6 +152,8 @@ test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=128,
workers_per_gpu=4,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(type='ConcatDataset', datasets=[train1, train2]),
val=dict(
type='ConcatDataset',
diff --git a/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py b/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py
index b003b823..121a9f56 100644
--- a/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py
+++ b/configs/textrecog/nrtr/nrtr_r31_1by8_1by4_academic.py
@@ -152,6 +152,8 @@ test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=128,
workers_per_gpu=4,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(type='ConcatDataset', datasets=[train1, train2]),
val=dict(
type='ConcatDataset',
diff --git a/configs/textrecog/robust_scanner/robustscanner_r31_academic.py b/configs/textrecog/robust_scanner/robustscanner_r31_academic.py
index e90dd75d..7c1b8636 100644
--- a/configs/textrecog/robust_scanner/robustscanner_r31_academic.py
+++ b/configs/textrecog/robust_scanner/robustscanner_r31_academic.py
@@ -182,6 +182,8 @@ test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=64,
workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='ConcatDataset',
datasets=[
diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py
index 4e405227..f00ab0cc 100644
--- a/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py
+++ b/configs/textrecog/sar/sar_r31_parallel_decoder_academic.py
@@ -204,6 +204,8 @@ test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=64,
workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='ConcatDataset',
datasets=[
diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_chinese.py b/configs/textrecog/sar/sar_r31_parallel_decoder_chinese.py
index 90e7a0de..5e58bbd7 100644
--- a/configs/textrecog/sar/sar_r31_parallel_decoder_chinese.py
+++ b/configs/textrecog/sar/sar_r31_parallel_decoder_chinese.py
@@ -119,6 +119,8 @@ test = dict(
data = dict(
samples_per_gpu=40,
workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(type='ConcatDataset', datasets=[train]),
val=dict(type='ConcatDataset', datasets=[test]),
test=dict(type='ConcatDataset', datasets=[test]))
diff --git a/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py b/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py
index 70c08464..75cd15e7 100755
--- a/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py
+++ b/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py
@@ -30,24 +30,19 @@ train_pipeline = [
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
- type='MultiRotateAugOCR',
- rotate_degrees=[0, 90, 270],
- transforms=[
- dict(
- type='ResizeOCR',
- height=48,
- min_width=48,
- max_width=160,
- keep_aspect_ratio=True),
- dict(type='ToTensorOCR'),
- dict(type='NormalizeOCR', **img_norm_cfg),
- dict(
- type='Collect',
- keys=['img'],
- meta_keys=[
- 'filename', 'ori_shape', 'img_shape', 'valid_ratio',
- 'img_norm_cfg', 'ori_filename'
- ]),
+ type='ResizeOCR',
+ height=48,
+ min_width=48,
+ max_width=160,
+ keep_aspect_ratio=True),
+ dict(type='ToTensorOCR'),
+ dict(type='NormalizeOCR', **img_norm_cfg),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'filename', 'ori_shape', 'img_shape', 'valid_ratio',
+ 'img_norm_cfg', 'ori_filename'
])
]
@@ -66,7 +61,7 @@ train1 = dict(
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
- pipeline=train_pipeline,
+ pipeline=None,
test_mode=False)
train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
@@ -82,7 +77,7 @@ train2 = dict(
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
- pipeline=train_pipeline,
+ pipeline=None,
test_mode=False)
test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
@@ -92,20 +87,25 @@ test = dict(
ann_file=test_anno_file1,
loader=dict(
type='LmdbLoader',
- repeat=1,
+ repeat=10,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
- pipeline=test_pipeline,
+ pipeline=None,
test_mode=True)
data = dict(
- samples_per_gpu=16,
workers_per_gpu=2,
- train=dict(type='ConcatDataset', datasets=[train1, train2]),
- val=dict(type='ConcatDataset', datasets=[test]),
- test=dict(type='ConcatDataset', datasets=[test]))
+ samples_per_gpu=8,
+ train=dict(
+ type='UniformConcatDataset',
+ datasets=[train1, train2],
+ pipeline=train_pipeline),
+ val=dict(
+ type='UniformConcatDataset', datasets=[test], pipeline=test_pipeline),
+ test=dict(
+ type='UniformConcatDataset', datasets=[test], pipeline=test_pipeline))
evaluation = dict(interval=1, metric='acc')
diff --git a/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py
index 6fa00dd7..5c468625 100644
--- a/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py
+++ b/configs/textrecog/sar/sar_r31_sequential_decoder_academic.py
@@ -204,6 +204,8 @@ test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=64,
workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=1),
+ test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='ConcatDataset',
datasets=[
diff --git a/docs/getting_started.md b/docs/getting_started.md
index 40578e1d..fbab8553 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -202,7 +202,12 @@ To support the tasks of `text detection`, `text recognition` and `key informatio
Here we show some examples of using different combination of `loader` and `parser`.
-#### Encoder-Decoder-Based Text Recognition Task
+#### Text Recognition Task
+
+##### OCRDataset
+
+*Dataset for encoder-decoder based recognizer*
+
```python
dataset_type = 'OCRDataset'
img_prefix = 'tests/data/ocr_toy_dataset/imgs'
@@ -225,7 +230,7 @@ train = dict(
You can check the content of the annotation file in `tests/data/ocr_toy_dataset/label.txt`.
The combination of `HardDiskLoader` and `LineStrParser` will return a dict for each file by calling `__getitem__`: `{'filename': '1223731.jpg', 'text': 'GRAND'}`.
-##### Optional Arguments:
+**Optional Arguments:**
- `repeat`: The number of repeated lines in the annotation files. For example, if there are `10` lines in the annotation file, setting `repeat=10` will generate a corresponding annotation file with size `100`.
@@ -254,7 +259,10 @@ train = dict(
test_mode=False)
```
-#### Segmentation-Based Text Recognition Task
+##### OCRSegDataset
+
+*Dataset for segmentation-based recognizer*
+
```python
prefix = 'tests/data/ocr_char_ann_toy_dataset/'
train = dict(
@@ -277,6 +285,11 @@ The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for
```
#### Text Detection Task
+
+##### TextDetDataset
+
+*Dataset with annotation file in line-json txt format*
+
```python
dataset_type = 'TextDetDataset'
img_prefix = 'tests/data/toy_dataset/imgs'
@@ -302,7 +315,10 @@ The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for
```
-### COCO-like Dataset
+##### IcdarDataset
+
+*Dataset with annotation file in coco-like json format*
+
For text detection, you can also use an annotation file in a COCO format that is defined in [mmdet](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/coco.py):
```python
dataset_type = 'IcdarDataset'
@@ -325,4 +341,40 @@ You can check the content of the annotation file in `tests/data/toy_dataset/inst
```shell
python tools/data_converter/ctw1500_converter.py ${src_root_path} -o ${out_path} --split-list training test
```
+
+#### UniformConcatDataset
+
+To use a universal pipeline for multiple datasets, we designed `UniformConcatDataset`.
+For example, apply `train_pipeline` for both `train1` and `train2`,
+
+```python
+data = dict(
+ ...
+ train=dict(
+ type='UniformConcatDataset',
+ datasets=[train1, train2],
+ pipeline=train_pipeline))
```
+
+Meanwhile, we have
+- train_dataloader
+- val_dataloader
+- test_dataloader
+
+to give specific settings. They will override the general settings in the `data` dict.
+For example,
+
+```python
+data = dict(
+ workers_per_gpu=2, # global setting
+ train_dataloader=dict(samples_per_gpu=8, drop_last=True), # train-specific setting
+ val_dataloader=dict(samples_per_gpu=8, workers_per_gpu=1), # val-specific setting
+ test_dataloader=dict(samples_per_gpu=8), # test-specific setting
+    ...)
+```
+`workers_per_gpu` is a global setting; `train_dataloader` and `val_dataloader` will inherit its value.
+`val_dataloader` overrides it with `workers_per_gpu=1`.
+
+To activate `batch inference` for `val` and `test`, please set `val_dataloader=dict(samples_per_gpu=8)` and `test_dataloader=dict(samples_per_gpu=8)` as above.
+Or just set `samples_per_gpu=8` as a global setting.
+See [config](/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py) for an example.
diff --git a/mmocr/apis/inference.py b/mmocr/apis/inference.py
index e9e5b777..1c009a8b 100644
--- a/mmocr/apis/inference.py
+++ b/mmocr/apis/inference.py
@@ -6,23 +6,40 @@ from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
-def disable_text_recog_aug_test(cfg):
+def disable_text_recog_aug_test(cfg, set_types=None):
"""Remove aug_test from test pipeline of text recognition.
Args:
cfg (mmcv.Config): Input config.
+ set_types (list[str]): Type of dataset source. Should be
+ None or sublist of ['test', 'val']
Returns:
cfg (mmcv.Config): Output config removing
`MultiRotateAugOCR` in test pipeline.
"""
- if cfg.data.test.pipeline[1].type == 'MultiRotateAugOCR':
- cfg.data.test.pipeline = [
- cfg.data.test.pipeline[0], *cfg.data.test.pipeline[1].transforms
- ]
+ assert set_types is None or isinstance(set_types, list)
+ if set_types is None:
+ set_types = ['val', 'test']
+ for set_type in set_types:
+ if cfg.data[set_type].pipeline[1].type == 'MultiRotateAugOCR':
+ cfg.data[set_type].pipeline = [
+ cfg.data[set_type].pipeline[0],
+ *cfg.data[set_type].pipeline[1].transforms
+ ]
+ assert_if_not_support_batch_mode(cfg, set_type)
return cfg
+def assert_if_not_support_batch_mode(cfg, set_type='test'):
+ if cfg.data[set_type].pipeline[1].type == 'ResizeOCR':
+ if cfg.data[set_type].pipeline[1].max_width is None:
+ raise Exception('Batch mode is not supported '
+ 'since the image width is not fixed, '
+ 'in the case that keeping aspect ratio but '
+ 'max_width is none when do resize.')
+
+
def model_inference(model, imgs, batch_mode=False):
"""Inference image(s) with the detector.
@@ -51,13 +68,7 @@ def model_inference(model, imgs, batch_mode=False):
cfg = model.cfg
if batch_mode:
- if cfg.data.test.pipeline[1].type == 'ResizeOCR':
- if cfg.data.test.pipeline[1].max_width is None:
- raise Exception('Free resize do not support batch mode '
- 'since the image width is not fixed, '
- 'for resize keeping aspect ratio and '
- 'max_width is not give.')
- cfg = disable_text_recog_aug_test(cfg)
+ cfg = disable_text_recog_aug_test(cfg, set_types=['test'])
device = next(model.parameters()).device # model device
diff --git a/mmocr/apis/train.py b/mmocr/apis/train.py
index e3bb9cc4..ecebeac4 100644
--- a/mmocr/apis/train.py
+++ b/mmocr/apis/train.py
@@ -10,6 +10,7 @@ from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
+from mmocr.apis.inference import disable_text_recog_aug_test
from mmocr.utils import get_root_logger
@@ -24,30 +25,32 @@ def train_detector(model,
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
- if 'imgs_per_gpu' in cfg.data:
- logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
- 'Please use "samples_per_gpu" instead')
- if 'samples_per_gpu' in cfg.data:
- logger.warning(
- f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
- f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
- f'={cfg.data.imgs_per_gpu} is used in this experiments')
- else:
- logger.warning(
- 'Automatically set "samples_per_gpu"="imgs_per_gpu"='
- f'{cfg.data.imgs_per_gpu} in this experiments')
- cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
-
- data_loaders = [
- build_dataloader(
- ds,
- cfg.data.samples_per_gpu,
- cfg.data.workers_per_gpu,
- # cfg.gpus will be ignored if distributed
- len(cfg.gpu_ids),
+ # step 1: give default values and override (if exist) from cfg.data
+ loader_cfg = {
+ **dict(
+ seed=cfg.get('seed'),
+ drop_last=False,
dist=distributed,
- seed=cfg.seed) for ds in dataset
- ]
+ num_gpus=len(cfg.gpu_ids)),
+ **({} if torch.__version__ != 'parrots' else dict(
+ prefetch_num=2,
+ pin_memory=False,
+ )),
+ **dict((k, cfg.data[k]) for k in [
+ 'samples_per_gpu',
+ 'workers_per_gpu',
+ 'shuffle',
+ 'seed',
+ 'drop_last',
+ 'prefetch_num',
+ 'pin_memory',
+ ] if k in cfg.data)
+ }
+
+ # step 2: cfg.data.train_dataloader has highest priority
+ train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))
+
+ data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
# put model on gpus
if distributed:
@@ -110,19 +113,28 @@ def train_detector(model,
# register eval hooks
if validate:
- # Support batch_size > 1 in validation
- val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
+ val_samples_per_gpu = (cfg.data.get('val_dataloader', {})).get(
+ 'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
if val_samples_per_gpu > 1:
- # Replace 'ImageToTensor' to 'DefaultFormatBundle'
- cfg.data.val.pipeline = replace_ImageToTensor(
- cfg.data.val.pipeline)
+ # Support batch_size > 1 in test for text recognition
+            # by disabling MultiRotateAugOCR since it is useless for most cases
+ cfg = disable_text_recog_aug_test(cfg)
+ if cfg.data.val.get('pipeline', None) is not None:
+ # Replace 'ImageToTensor' to 'DefaultFormatBundle'
+ cfg.data.val.pipeline = replace_ImageToTensor(
+ cfg.data.val.pipeline)
+
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
- val_dataloader = build_dataloader(
- val_dataset,
- samples_per_gpu=val_samples_per_gpu,
- workers_per_gpu=cfg.data.workers_per_gpu,
- dist=distributed,
- shuffle=False)
+
+ val_loader_cfg = {
+ **loader_cfg,
+ **dict(shuffle=False, drop_last=False),
+ **cfg.data.get('val_dataloader', {}),
+ **dict(samples_per_gpu=val_samples_per_gpu)
+ }
+
+ val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
+
eval_cfg = cfg.get('evaluation', {})
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
eval_hook = DistEvalHook if distributed else EvalHook
diff --git a/mmocr/datasets/__init__.py b/mmocr/datasets/__init__.py
index b4bad7b4..886d098e 100644
--- a/mmocr/datasets/__init__.py
+++ b/mmocr/datasets/__init__.py
@@ -9,6 +9,7 @@ from .ocr_dataset import OCRDataset
from .ocr_seg_dataset import OCRSegDataset
from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets
from .text_det_dataset import TextDetDataset
+from .uniform_concat_dataset import UniformConcatDataset
from .utils import * # NOQA
@@ -16,7 +17,7 @@ __all__ = [
'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets',
- 'NerDataset'
+ 'NerDataset', 'UniformConcatDataset'
]
__all__ += utils.__all__
diff --git a/mmocr/datasets/uniform_concat_dataset.py b/mmocr/datasets/uniform_concat_dataset.py
new file mode 100644
index 00000000..0fe7275c
--- /dev/null
+++ b/mmocr/datasets/uniform_concat_dataset.py
@@ -0,0 +1,27 @@
+import copy
+
+from mmdet.datasets import DATASETS, ConcatDataset, build_dataset
+
+
+@DATASETS.register_module()
+class UniformConcatDataset(ConcatDataset):
+ """A wrapper of concatenated dataset.
+
+ Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
+ concat the group flag for image aspect ratio.
+
+ Args:
+ datasets (list[:obj:`Dataset`]): A list of datasets.
+ separate_eval (bool): Whether to evaluate the results
+ separately if it is used as validation dataset.
+ Defaults to True.
+ """
+
+ def __init__(self, datasets, separate_eval=True, pipeline=None, **kwargs):
+ from_cfg = all(isinstance(x, dict) for x in datasets)
+ if pipeline is not None:
+ assert from_cfg, 'datasets should be config dicts'
+ for dataset in datasets:
+ dataset['pipeline'] = copy.deepcopy(pipeline)
+ datasets = [build_dataset(c, kwargs) for c in datasets]
+ super().__init__(datasets, separate_eval)
diff --git a/tests/test_apis/test_model_inference.py b/tests/test_apis/test_model_inference.py
index 63036f8d..e7c164aa 100644
--- a/tests/test_apis/test_model_inference.py
+++ b/tests/test_apis/test_model_inference.py
@@ -114,19 +114,19 @@ def test_model_batch_inference_raises_exception_error_free_resize_recog(
with pytest.raises(
Exception,
- match='Free resize do not support batch mode '
+ match='Batch mode is not supported '
'since the image width is not fixed, '
- 'for resize keeping aspect ratio and '
- 'max_width is not give.'):
+ 'in the case that keeping aspect ratio but '
+ 'max_width is none when do resize.'):
sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
model_inference(
model, [sample_img_path, sample_img_path], batch_mode=True)
with pytest.raises(
Exception,
- match='Free resize do not support batch mode '
+ match='Batch mode is not supported '
'since the image width is not fixed, '
- 'for resize keeping aspect ratio and '
- 'max_width is not give.'):
+ 'in the case that keeping aspect ratio but '
+ 'max_width is none when do resize.'):
img = imread(sample_img_path)
model_inference(model, [img, img], batch_mode=True)
diff --git a/tools/test.py b/tools/test.py
index 63593168..fd352460 100755
--- a/tools/test.py
+++ b/tools/test.py
@@ -13,6 +13,7 @@ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
from mmdet.apis import multi_gpu_test, single_gpu_test
from mmdet.datasets import replace_ImageToTensor
+from mmocr.apis.inference import disable_text_recog_aug_test
from mmocr.datasets import build_dataloader, build_dataset
from mmocr.models import build_detector
@@ -140,12 +141,16 @@ def main():
# in case the test dataset is concatenated
samples_per_gpu = 1
if isinstance(cfg.data.test, dict):
- cfg.data.test.test_mode = True
- samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
+ samples_per_gpu = (cfg.data.get('test_dataloader', {})).get(
+ 'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
if samples_per_gpu > 1:
- # Replace 'ImageToTensor' to 'DefaultFormatBundle'
- cfg.data.test.pipeline = replace_ImageToTensor(
- cfg.data.test.pipeline)
+ # Support batch_size > 1 in test for text recognition
+            # by disabling MultiRotateAugOCR since it is useless for most cases
+ cfg = disable_text_recog_aug_test(cfg)
+ if cfg.data.test.get('pipeline', None) is not None:
+ # Replace 'ImageToTensor' to 'DefaultFormatBundle'
+ cfg.data.test.pipeline = replace_ImageToTensor(
+ cfg.data.test.pipeline)
elif isinstance(cfg.data.test, list):
for ds_cfg in cfg.data.test:
ds_cfg.test_mode = True
@@ -163,13 +168,29 @@ def main():
init_dist(args.launcher, **cfg.dist_params)
# build the dataloader
- dataset = build_dataset(cfg.data.test)
- data_loader = build_dataloader(
- dataset,
- samples_per_gpu=samples_per_gpu,
- workers_per_gpu=cfg.data.workers_per_gpu,
- dist=distributed,
- shuffle=False)
+ dataset = build_dataset(cfg.data.test, dict(test_mode=True))
+ # step 1: give default values and override (if exist) from cfg.data
+ loader_cfg = {
+ **dict(seed=cfg.get('seed'), drop_last=False, dist=distributed),
+ **({} if torch.__version__ != 'parrots' else dict(
+ prefetch_num=2,
+ pin_memory=False,
+ )),
+ **dict((k, cfg.data[k]) for k in [
+ 'workers_per_gpu',
+ 'seed',
+ 'prefetch_num',
+ 'pin_memory',
+ ] if k in cfg.data)
+ }
+ test_loader_cfg = {
+ **loader_cfg,
+ **dict(shuffle=False, drop_last=False),
+ **cfg.data.get('test_dataloader', {}),
+ **dict(samples_per_gpu=samples_per_gpu)
+ }
+
+ data_loader = build_dataloader(dataset, **test_loader_cfg)
# build the model and load checkpoint
cfg.model.train_cfg = None