mirror of https://github.com/open-mmlab/mmocr.git

support batch inference during testing (#310)

* support batch inference during testing
* fix unittest
* update docs using url
* set cfg for train, val and test
* update docs
* update docs and test.py
* samples_per_gpu as global setting
* changes revert

pull/316/head
parent e6cb750922
commit 82f64a5b62
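The pattern rolled out across the config hunks below is a global `samples_per_gpu`/`workers_per_gpu` pair plus optional per-split `val_dataloader`/`test_dataloader` dicts. A minimal sketch of that layout, with a placeholder dataset type and annotation paths (not taken from any particular config):

```python
# Minimal config sketch: global dataloader settings plus per-split overrides.
# 'ToyDataset' and the annotation paths below are placeholders.
dataset_type = 'ToyDataset'
data_root = 'data/toy'

data = dict(
    samples_per_gpu=8,                        # global batch size (used by train by default)
    workers_per_gpu=2,                        # global worker count
    val_dataloader=dict(samples_per_gpu=1),   # val-specific override
    test_dataloader=dict(samples_per_gpu=1),  # test-specific override
    train=dict(type=dataset_type, ann_file=data_root + '/instances_training.json'),
    val=dict(type=dataset_type, ann_file=data_root + '/instances_test.json'),
    test=dict(type=dataset_type, ann_file=data_root + '/instances_test.json'))
```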
@@ -74,6 +74,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=16,
     workers_per_gpu=8,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=data_root + '/instances_training.json',
@@ -83,6 +83,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=8,
     workers_per_gpu=4,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=data_root + '/instances_training.json',
@@ -91,6 +91,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=4,
     workers_per_gpu=4,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=f'{data_root}/instances_training.json',
@@ -96,6 +96,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=8,
     workers_per_gpu=2,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=data_root + '/instances_training.json',
@@ -95,6 +95,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=6,
     workers_per_gpu=2,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=data_root + '/instances_training.json',
@@ -47,6 +47,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=8,
     workers_per_gpu=4,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=data_root + '/instances_training.json',
@@ -46,6 +46,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=8,
     workers_per_gpu=4,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=data_root + '/instances_training.json',
@@ -47,6 +47,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=8,
     workers_per_gpu=4,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=data_root + '/instances_training.json',
@@ -82,6 +82,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=2,
     workers_per_gpu=2,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=data_root + '/instances_training.json',
@@ -80,6 +80,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=8,
     workers_per_gpu=2,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=data_root + '/instances_training.json',
@@ -75,6 +75,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=4,
     workers_per_gpu=4,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=data_root + '/instances_training.json',
@@ -89,6 +89,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=2,
     workers_per_gpu=2,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=data_root + '/instances_training.json',
@@ -89,6 +89,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=8,
     workers_per_gpu=2,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=data_root + '/instances_training.json',
@@ -84,6 +84,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=8,
     workers_per_gpu=4,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=data_root + '/instances_training.json',
@@ -94,6 +94,8 @@ test_pipeline = [
 data = dict(
     samples_per_gpu=4,
     workers_per_gpu=4,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type=dataset_type,
         ann_file=f'{data_root}/instances_training.json',
@@ -148,6 +148,8 @@ test6['ann_file'] = test_ann_file6
 data = dict(
     samples_per_gpu=64,
     workers_per_gpu=4,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(type='ConcatDataset', datasets=[train1]),
     val=dict(
         type='ConcatDataset',
@@ -152,6 +152,8 @@ test6['ann_file'] = test_ann_file6
 data = dict(
     samples_per_gpu=128,
     workers_per_gpu=4,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(type='ConcatDataset', datasets=[train1, train2]),
     val=dict(
         type='ConcatDataset',
@@ -152,6 +152,8 @@ test6['ann_file'] = test_ann_file6
 data = dict(
     samples_per_gpu=128,
     workers_per_gpu=4,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(type='ConcatDataset', datasets=[train1, train2]),
     val=dict(
         type='ConcatDataset',
@@ -182,6 +182,8 @@ test6['ann_file'] = test_ann_file6
 data = dict(
     samples_per_gpu=64,
     workers_per_gpu=2,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type='ConcatDataset',
         datasets=[
@@ -204,6 +204,8 @@ test6['ann_file'] = test_ann_file6
 data = dict(
     samples_per_gpu=64,
     workers_per_gpu=2,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type='ConcatDataset',
         datasets=[
@@ -119,6 +119,8 @@ test = dict(
 data = dict(
     samples_per_gpu=40,
     workers_per_gpu=2,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(type='ConcatDataset', datasets=[train]),
     val=dict(type='ConcatDataset', datasets=[test]),
     test=dict(type='ConcatDataset', datasets=[test]))
@@ -30,24 +30,19 @@ train_pipeline = [
 test_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(
-        type='MultiRotateAugOCR',
-        rotate_degrees=[0, 90, 270],
-        transforms=[
-            dict(
-                type='ResizeOCR',
-                height=48,
-                min_width=48,
-                max_width=160,
-                keep_aspect_ratio=True),
-            dict(type='ToTensorOCR'),
-            dict(type='NormalizeOCR', **img_norm_cfg),
-            dict(
-                type='Collect',
-                keys=['img'],
-                meta_keys=[
-                    'filename', 'ori_shape', 'img_shape', 'valid_ratio',
-                    'img_norm_cfg', 'ori_filename'
-                ]),
-        ])
+        type='ResizeOCR',
+        height=48,
+        min_width=48,
+        max_width=160,
+        keep_aspect_ratio=True),
+    dict(type='ToTensorOCR'),
+    dict(type='NormalizeOCR', **img_norm_cfg),
+    dict(
+        type='Collect',
+        keys=['img'],
+        meta_keys=[
+            'filename', 'ori_shape', 'img_shape', 'valid_ratio',
+            'img_norm_cfg', 'ori_filename'
+        ])
 ]
@@ -66,7 +61,7 @@ train1 = dict(
             keys=['filename', 'text'],
             keys_idx=[0, 1],
             separator=' ')),
-    pipeline=train_pipeline,
+    pipeline=None,
     test_mode=False)

 train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
@@ -82,7 +77,7 @@ train2 = dict(
             keys=['filename', 'text'],
             keys_idx=[0, 1],
             separator=' ')),
-    pipeline=train_pipeline,
+    pipeline=None,
     test_mode=False)

 test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
@@ -92,20 +87,25 @@ test = dict(
     ann_file=test_anno_file1,
     loader=dict(
         type='LmdbLoader',
-        repeat=1,
+        repeat=10,
         parser=dict(
             type='LineStrParser',
             keys=['filename', 'text'],
             keys_idx=[0, 1],
             separator=' ')),
-    pipeline=test_pipeline,
+    pipeline=None,
     test_mode=True)

 data = dict(
-    samples_per_gpu=16,
+    samples_per_gpu=8,
     workers_per_gpu=2,
-    train=dict(type='ConcatDataset', datasets=[train1, train2]),
-    val=dict(type='ConcatDataset', datasets=[test]),
-    test=dict(type='ConcatDataset', datasets=[test]))
+    train=dict(
+        type='UniformConcatDataset',
+        datasets=[train1, train2],
+        pipeline=train_pipeline),
+    val=dict(
+        type='UniformConcatDataset', datasets=[test], pipeline=test_pipeline),
+    test=dict(
+        type='UniformConcatDataset', datasets=[test], pipeline=test_pipeline))

 evaluation = dict(interval=1, metric='acc')
@@ -204,6 +204,8 @@ test6['ann_file'] = test_ann_file6
 data = dict(
     samples_per_gpu=64,
     workers_per_gpu=2,
+    val_dataloader=dict(samples_per_gpu=1),
+    test_dataloader=dict(samples_per_gpu=1),
     train=dict(
         type='ConcatDataset',
         datasets=[
@@ -202,7 +202,12 @@ To support the tasks of `text detection`, `text recognition` and `key informatio

 Here we show some examples of using different combination of `loader` and `parser`.

-#### Encoder-Decoder-Based Text Recognition Task
+#### Text Recognition Task
+
+##### OCRDataset
+
+<small>*Dataset for encoder-decoder based recognizer*</small>

 ```python
 dataset_type = 'OCRDataset'
 img_prefix = 'tests/data/ocr_toy_dataset/imgs'
@@ -225,7 +230,7 @@ train = dict(
 You can check the content of the annotation file in `tests/data/ocr_toy_dataset/label.txt`.
 The combination of `HardDiskLoader` and `LineStrParser` will return a dict for each file by calling `__getitem__`: `{'filename': '1223731.jpg', 'text': 'GRAND'}`.

-##### Optional Arguments:
+**Optional Arguments:**

 - `repeat`: The number of repeated lines in the annotation files. For example, if there are `10` lines in the annotation file, setting `repeat=10` will generate a corresponding annotation file with size `100`.
@@ -254,7 +259,10 @@ train = dict(
     test_mode=False)
 ```

-#### Segmentation-Based Text Recognition Task
+##### OCRSegDataset
+
+<small>*Dataset for segmentation-based recognizer*</small>

 ```python
 prefix = 'tests/data/ocr_char_ann_toy_dataset/'
 train = dict(
@@ -277,6 +285,11 @@ The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for
 ```

+#### Text Detection Task
+
+##### TextDetDataset
+
+<small>*Dataset with annotation file in line-json txt format*</small>

 ```python
 dataset_type = 'TextDetDataset'
 img_prefix = 'tests/data/toy_dataset/imgs'
@@ -302,7 +315,10 @@ The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for
 ```


-### COCO-like Dataset
+##### IcdarDataset
+
+<small>*Dataset with annotation file in coco-like json format*</small>
+
 For text detection, you can also use an annotation file in a COCO format that is defined in [mmdet](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/coco.py):
 ```python
 dataset_type = 'IcdarDataset'
@@ -325,4 +341,40 @@ You can check the content of the annotation file in `tests/data/toy_dataset/inst
 ```shell
 python tools/data_converter/ctw1500_converter.py ${src_root_path} -o ${out_path} --split-list training test
 ```
+
+#### UniformConcatDataset
+
+To apply a universal pipeline to multiple datasets, we design `UniformConcatDataset`.
+For example, to apply `train_pipeline` to both `train1` and `train2`:
+
+```python
+data = dict(
+    ...
+    train=dict(
+        type='UniformConcatDataset',
+        datasets=[train1, train2],
+        pipeline=train_pipeline))
+```
+
+Meanwhile, we have
+- `train_dataloader`
+- `val_dataloader`
+- `test_dataloader`
+
+to give split-specific settings, which override the general settings in the `data` dict.
+For example:
+
+```python
+data = dict(
+    workers_per_gpu=2, # global setting
+    train_dataloader=dict(samples_per_gpu=8, drop_last=True), # train-specific setting
+    val_dataloader=dict(samples_per_gpu=8, workers_per_gpu=1), # val-specific setting
+    test_dataloader=dict(samples_per_gpu=8), # test-specific setting
+    ...
+```
+`workers_per_gpu` is a global setting that `train_dataloader` and `test_dataloader` inherit, while `val_dataloader` overrides it with `workers_per_gpu=1`.
+
+To activate batch inference for `val` and `test`, set `val_dataloader=dict(samples_per_gpu=8)` and `test_dataloader=dict(samples_per_gpu=8)` as above, or simply set `samples_per_gpu=8` as a global setting.
+See [config](/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py) for an example.
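The override rule described above is plain dictionary precedence: a split-specific `*_dataloader` dict wins over the global keys, and anything it does not set is inherited. A small self-contained sketch of that resolution (not the actual mmocr implementation; the helper function is made up for illustration):

```python
# Plain-dict sketch of the per-split override behaviour described above.
data = dict(
    samples_per_gpu=4,   # global setting
    workers_per_gpu=2,   # global setting
    train_dataloader=dict(samples_per_gpu=8, drop_last=True),
    val_dataloader=dict(samples_per_gpu=8, workers_per_gpu=1),
    test_dataloader=dict(samples_per_gpu=8),
)


def effective_loader_cfg(split):
    """Merge the global keys with the split-specific *_dataloader dict."""
    global_cfg = {k: data[k] for k in ('samples_per_gpu', 'workers_per_gpu') if k in data}
    return {**global_cfg, **data.get(f'{split}_dataloader', {})}


print(effective_loader_cfg('val'))
# {'samples_per_gpu': 8, 'workers_per_gpu': 1}
```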
@@ -6,23 +6,40 @@ from mmdet.datasets import replace_ImageToTensor
 from mmdet.datasets.pipelines import Compose


-def disable_text_recog_aug_test(cfg):
+def disable_text_recog_aug_test(cfg, set_types=None):
     """Remove aug_test from test pipeline of text recognition.
     Args:
         cfg (mmcv.Config): Input config.
+        set_types (list[str]): Type of dataset source. Should be
+            None or sublist of ['test', 'val']

     Returns:
         cfg (mmcv.Config): Output config removing
             `MultiRotateAugOCR` in test pipeline.
     """
-    if cfg.data.test.pipeline[1].type == 'MultiRotateAugOCR':
-        cfg.data.test.pipeline = [
-            cfg.data.test.pipeline[0], *cfg.data.test.pipeline[1].transforms
-        ]
+    assert set_types is None or isinstance(set_types, list)
+    if set_types is None:
+        set_types = ['val', 'test']
+    for set_type in set_types:
+        if cfg.data[set_type].pipeline[1].type == 'MultiRotateAugOCR':
+            cfg.data[set_type].pipeline = [
+                cfg.data[set_type].pipeline[0],
+                *cfg.data[set_type].pipeline[1].transforms
+            ]
+            assert_if_not_support_batch_mode(cfg, set_type)

     return cfg


+def assert_if_not_support_batch_mode(cfg, set_type='test'):
+    if cfg.data[set_type].pipeline[1].type == 'ResizeOCR':
+        if cfg.data[set_type].pipeline[1].max_width is None:
+            raise Exception('Batch mode is not supported '
+                            'since the image width is not fixed, '
+                            'in the case that keeping aspect ratio but '
+                            'max_width is none when do resize.')
+
+
 def model_inference(model, imgs, batch_mode=False):
     """Inference image(s) with the detector.

@@ -51,13 +68,7 @@ def model_inference(model, imgs, batch_mode=False):
     cfg = model.cfg

     if batch_mode:
-        if cfg.data.test.pipeline[1].type == 'ResizeOCR':
-            if cfg.data.test.pipeline[1].max_width is None:
-                raise Exception('Free resize do not support batch mode '
-                                'since the image width is not fixed, '
-                                'for resize keeping aspect ratio and '
-                                'max_width is not give.')
-        cfg = disable_text_recog_aug_test(cfg)
+        cfg = disable_text_recog_aug_test(cfg, set_types=['test'])

     device = next(model.parameters()).device  # model device

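With the helpers above, batch inference for a recognizer goes through `model_inference` with `batch_mode=True`; the call internally applies `disable_text_recog_aug_test(cfg, set_types=['test'])` and requires a fixed `max_width` in `ResizeOCR`. A usage sketch, assuming `recog_model` is a recognizer that has already been built and loaded elsewhere and that the demo image exists at the path used in the tests:

```python
from mmocr.apis import model_inference

# `recog_model` is assumed to be an already-built, checkpoint-loaded recognizer
# whose test pipeline uses ResizeOCR with a fixed max_width (needed for batching).
img_paths = ['demo/demo_text_recog.jpg', 'demo/demo_text_recog.jpg']
results = model_inference(recog_model, img_paths, batch_mode=True)
print(len(results))  # one result per input image
```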
@@ -10,6 +10,7 @@ from mmdet.core import DistEvalHook, EvalHook
 from mmdet.datasets import (build_dataloader, build_dataset,
                             replace_ImageToTensor)

+from mmocr.apis.inference import disable_text_recog_aug_test
 from mmocr.utils import get_root_logger

@@ -24,30 +25,32 @@ def train_detector(model,

     # prepare data loaders
     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
-    if 'imgs_per_gpu' in cfg.data:
-        logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
-                       'Please use "samples_per_gpu" instead')
-        if 'samples_per_gpu' in cfg.data:
-            logger.warning(
-                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
-                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
-                f'={cfg.data.imgs_per_gpu} is used in this experiments')
-        else:
-            logger.warning(
-                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
-                f'{cfg.data.imgs_per_gpu} in this experiments')
-            cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
-
-    data_loaders = [
-        build_dataloader(
-            ds,
-            cfg.data.samples_per_gpu,
-            cfg.data.workers_per_gpu,
-            # cfg.gpus will be ignored if distributed
-            len(cfg.gpu_ids),
-            dist=distributed,
-            seed=cfg.seed) for ds in dataset
-    ]
+
+    # step 1: give default values and override (if exist) from cfg.data
+    loader_cfg = {
+        **dict(
+            seed=cfg.get('seed'),
+            drop_last=False,
+            dist=distributed,
+            num_gpus=len(cfg.gpu_ids)),
+        **({} if torch.__version__ != 'parrots' else dict(
+            prefetch_num=2,
+            pin_memory=False,
+        )),
+        **dict((k, cfg.data[k]) for k in [
+            'samples_per_gpu',
+            'workers_per_gpu',
+            'shuffle',
+            'seed',
+            'drop_last',
+            'prefetch_num',
+            'pin_memory',
+        ] if k in cfg.data)
+    }
+
+    # step 2: cfg.data.train_dataloader has highest priority
+    train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))
+
+    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

     # put model on gpus
     if distributed:
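The two-step construction above is ordinary `dict` unpacking: defaults first, then whatever matching keys exist in `cfg.data`, then `cfg.data.train_dataloader` last, so later sources win. A stripped-down sketch of the same priority with plain dicts (values are illustrative only):

```python
# Later ** entries override earlier ones; that is the whole priority scheme.
defaults = dict(seed=None, drop_last=False, dist=False, num_gpus=1)
from_cfg_data = dict(samples_per_gpu=8, workers_per_gpu=2)    # subset picked from cfg.data
train_override = dict(samples_per_gpu=16, drop_last=True)     # cfg.data.train_dataloader

train_loader_cfg = {**defaults, **from_cfg_data, **train_override}
print(train_loader_cfg['samples_per_gpu'], train_loader_cfg['drop_last'])  # 16 True
```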
@@ -110,19 +113,28 @@ def train_detector(model,

     # register eval hooks
     if validate:
-        # Support batch_size > 1 in validation
-        val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
+        val_samples_per_gpu = (cfg.data.get('val_dataloader', {})).get(
+            'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
         if val_samples_per_gpu > 1:
-            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
-            cfg.data.val.pipeline = replace_ImageToTensor(
-                cfg.data.val.pipeline)
+            # Support batch_size > 1 in test for text recognition
+            # by disable MultiRotateAugOCR since it is useless for most case
+            cfg = disable_text_recog_aug_test(cfg)
+            if cfg.data.val.get('pipeline', None) is not None:
+                # Replace 'ImageToTensor' to 'DefaultFormatBundle'
+                cfg.data.val.pipeline = replace_ImageToTensor(
+                    cfg.data.val.pipeline)
+
         val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
-        val_dataloader = build_dataloader(
-            val_dataset,
-            samples_per_gpu=val_samples_per_gpu,
-            workers_per_gpu=cfg.data.workers_per_gpu,
-            dist=distributed,
-            shuffle=False)
+
+        val_loader_cfg = {
+            **loader_cfg,
+            **dict(shuffle=False, drop_last=False),
+            **cfg.data.get('val_dataloader', {}),
+            **dict(samples_per_gpu=val_samples_per_gpu)
+        }
+
+        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
+
         eval_cfg = cfg.get('evaluation', {})
         eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
         eval_hook = DistEvalHook if distributed else EvalHook
@@ -9,6 +9,7 @@ from .ocr_dataset import OCRDataset
 from .ocr_seg_dataset import OCRSegDataset
 from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets
 from .text_det_dataset import TextDetDataset
+from .uniform_concat_dataset import UniformConcatDataset

 from .utils import *  # NOQA

@@ -16,7 +17,7 @@ __all__ = [
     'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
     'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
     'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets',
-    'NerDataset'
+    'NerDataset', 'UniformConcatDataset'
 ]

 __all__ += utils.__all__
@@ -0,0 +1,27 @@
+import copy
+
+from mmdet.datasets import DATASETS, ConcatDataset, build_dataset
+
+
+@DATASETS.register_module()
+class UniformConcatDataset(ConcatDataset):
+    """A wrapper of concatenated dataset.
+
+    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
+    concat the group flag for image aspect ratio.
+
+    Args:
+        datasets (list[:obj:`Dataset`]): A list of datasets.
+        separate_eval (bool): Whether to evaluate the results
+            separately if it is used as validation dataset.
+            Defaults to True.
+    """
+
+    def __init__(self, datasets, separate_eval=True, pipeline=None, **kwargs):
+        from_cfg = all(isinstance(x, dict) for x in datasets)
+        if pipeline is not None:
+            assert from_cfg, 'datasets should be config dicts'
+            for dataset in datasets:
+                dataset['pipeline'] = copy.deepcopy(pipeline)
+        datasets = [build_dataset(c, kwargs) for c in datasets]
+        super().__init__(datasets, separate_eval)
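A usage sketch of the new wrapper, assuming `train1`, `train2` and `train_pipeline` are the dataset config dicts and pipeline defined in the toy config earlier in this commit: the shared pipeline is deep-copied into every child config before each child is built.

```python
from mmocr.datasets import UniformConcatDataset

# train1, train2 and train_pipeline as in the toy dataset config above;
# each child config receives its own deep copy of train_pipeline here.
dataset = UniformConcatDataset(
    datasets=[train1, train2],
    pipeline=train_pipeline)
print(len(dataset))  # sum of the two child dataset lengths
```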
@@ -114,19 +114,19 @@ def test_model_batch_inference_raises_exception_error_free_resize_recog(

     with pytest.raises(
             Exception,
-            match='Free resize do not support batch mode '
+            match='Batch mode is not supported '
             'since the image width is not fixed, '
-            'for resize keeping aspect ratio and '
-            'max_width is not give.'):
+            'in the case that keeping aspect ratio but '
+            'max_width is none when do resize.'):
         sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
         model_inference(
             model, [sample_img_path, sample_img_path], batch_mode=True)

     with pytest.raises(
             Exception,
-            match='Free resize do not support batch mode '
+            match='Batch mode is not supported '
             'since the image width is not fixed, '
-            'for resize keeping aspect ratio and '
-            'max_width is not give.'):
+            'in the case that keeping aspect ratio but '
+            'max_width is none when do resize.'):
         img = imread(sample_img_path)
         model_inference(model, [img, img], batch_mode=True)
@@ -13,6 +13,7 @@ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
 from mmdet.apis import multi_gpu_test, single_gpu_test
 from mmdet.datasets import replace_ImageToTensor

+from mmocr.apis.inference import disable_text_recog_aug_test
 from mmocr.datasets import build_dataloader, build_dataset
 from mmocr.models import build_detector

@@ -140,12 +141,16 @@ def main():
     # in case the test dataset is concatenated
     samples_per_gpu = 1
     if isinstance(cfg.data.test, dict):
         cfg.data.test.test_mode = True
-        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
+        samples_per_gpu = (cfg.data.get('test_dataloader', {})).get(
+            'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
         if samples_per_gpu > 1:
-            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
-            cfg.data.test.pipeline = replace_ImageToTensor(
-                cfg.data.test.pipeline)
+            # Support batch_size > 1 in test for text recognition
+            # by disable MultiRotateAugOCR since it is useless for most case
+            cfg = disable_text_recog_aug_test(cfg)
+            if cfg.data.test.get('pipeline', None) is not None:
+                # Replace 'ImageToTensor' to 'DefaultFormatBundle'
+                cfg.data.test.pipeline = replace_ImageToTensor(
+                    cfg.data.test.pipeline)
     elif isinstance(cfg.data.test, list):
         for ds_cfg in cfg.data.test:
             ds_cfg.test_mode = True
@@ -163,13 +168,29 @@ def main():
         init_dist(args.launcher, **cfg.dist_params)

     # build the dataloader
-    dataset = build_dataset(cfg.data.test)
-    data_loader = build_dataloader(
-        dataset,
-        samples_per_gpu=samples_per_gpu,
-        workers_per_gpu=cfg.data.workers_per_gpu,
-        dist=distributed,
-        shuffle=False)
+    dataset = build_dataset(cfg.data.test, dict(test_mode=True))
+    # step 1: give default values and override (if exist) from cfg.data
+    loader_cfg = {
+        **dict(seed=cfg.get('seed'), drop_last=False, dist=distributed),
+        **({} if torch.__version__ != 'parrots' else dict(
+            prefetch_num=2,
+            pin_memory=False,
+        )),
+        **dict((k, cfg.data[k]) for k in [
+            'workers_per_gpu',
+            'seed',
+            'prefetch_num',
+            'pin_memory',
+        ] if k in cfg.data)
+    }
+    test_loader_cfg = {
+        **loader_cfg,
+        **dict(shuffle=False, drop_last=False),
+        **cfg.data.get('test_dataloader', {}),
+        **dict(samples_per_gpu=samples_per_gpu)
+    }
+
+    data_loader = build_dataloader(dataset, **test_loader_cfg)

     # build the model and load checkpoint
     cfg.model.train_cfg = None