support batch inference during testing (#310)

* support batch inference during testing

* fix unittest

* update docs using url

* set cfg for train, val and test

* update docs

* update docs and test.py

* samples_per_gpu as global setting

* changes revert
pull/316/head
Hongbin Sun 2021-06-23 11:34:29 +08:00 committed by GitHub
parent e6cb750922
commit 82f64a5b62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 263 additions and 95 deletions

View File

@ -74,6 +74,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=16,
workers_per_gpu=8,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',

View File

@ -83,6 +83,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',

View File

@ -91,6 +91,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=f'{data_root}/instances_training.json',

View File

@ -96,6 +96,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=2,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',

View File

@ -95,6 +95,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=6,
workers_per_gpu=2,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',

View File

@ -47,6 +47,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',

View File

@ -46,6 +46,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',

View File

@ -47,6 +47,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',

View File

@ -82,6 +82,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',

View File

@ -80,6 +80,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=2,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',

View File

@ -75,6 +75,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',

View File

@ -89,6 +89,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',

View File

@ -89,6 +89,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=2,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',

View File

@ -84,6 +84,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=8,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',

View File

@ -94,6 +94,8 @@ test_pipeline = [
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type=dataset_type,
ann_file=f'{data_root}/instances_training.json',

View File

@ -148,6 +148,8 @@ test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(type='ConcatDataset', datasets=[train1]),
val=dict(
type='ConcatDataset',

View File

@ -152,6 +152,8 @@ test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=128,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(type='ConcatDataset', datasets=[train1, train2]),
val=dict(
type='ConcatDataset',

View File

@ -152,6 +152,8 @@ test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=128,
workers_per_gpu=4,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(type='ConcatDataset', datasets=[train1, train2]),
val=dict(
type='ConcatDataset',

View File

@ -182,6 +182,8 @@ test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=64,
workers_per_gpu=2,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='ConcatDataset',
datasets=[

View File

@ -204,6 +204,8 @@ test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=64,
workers_per_gpu=2,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='ConcatDataset',
datasets=[

View File

@ -119,6 +119,8 @@ test = dict(
data = dict(
samples_per_gpu=40,
workers_per_gpu=2,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(type='ConcatDataset', datasets=[train]),
val=dict(type='ConcatDataset', datasets=[test]),
test=dict(type='ConcatDataset', datasets=[test]))

View File

@ -30,24 +30,19 @@ train_pipeline = [
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiRotateAugOCR',
rotate_degrees=[0, 90, 270],
transforms=[
dict(
type='ResizeOCR',
height=48,
min_width=48,
max_width=160,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio',
'img_norm_cfg', 'ori_filename'
]),
type='ResizeOCR',
height=48,
min_width=48,
max_width=160,
keep_aspect_ratio=True),
dict(type='ToTensorOCR'),
dict(type='NormalizeOCR', **img_norm_cfg),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'filename', 'ori_shape', 'img_shape', 'valid_ratio',
'img_norm_cfg', 'ori_filename'
])
]
@ -66,7 +61,7 @@ train1 = dict(
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=train_pipeline,
pipeline=None,
test_mode=False)
train_anno_file2 = 'tests/data/ocr_toy_dataset/label.lmdb'
@ -82,7 +77,7 @@ train2 = dict(
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=train_pipeline,
pipeline=None,
test_mode=False)
test_anno_file1 = 'tests/data/ocr_toy_dataset/label.lmdb'
@ -92,20 +87,25 @@ test = dict(
ann_file=test_anno_file1,
loader=dict(
type='LmdbLoader',
repeat=1,
repeat=10,
parser=dict(
type='LineStrParser',
keys=['filename', 'text'],
keys_idx=[0, 1],
separator=' ')),
pipeline=test_pipeline,
pipeline=None,
test_mode=True)
data = dict(
samples_per_gpu=16,
workers_per_gpu=2,
train=dict(type='ConcatDataset', datasets=[train1, train2]),
val=dict(type='ConcatDataset', datasets=[test]),
test=dict(type='ConcatDataset', datasets=[test]))
samples_per_gpu=8,
train=dict(
type='UniformConcatDataset',
datasets=[train1, train2],
pipeline=train_pipeline),
val=dict(
type='UniformConcatDataset', datasets=[test], pipeline=test_pipeline),
test=dict(
type='UniformConcatDataset', datasets=[test], pipeline=test_pipeline))
evaluation = dict(interval=1, metric='acc')

View File

@ -204,6 +204,8 @@ test6['ann_file'] = test_ann_file6
data = dict(
samples_per_gpu=64,
workers_per_gpu=2,
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='ConcatDataset',
datasets=[

View File

@ -202,7 +202,12 @@ To support the tasks of `text detection`, `text recognition` and `key informatio
Here we show some examples of using different combination of `loader` and `parser`.
#### Encoder-Decoder-Based Text Recognition Task
#### Text Recognition Task
##### OCRDataset
<small>*Dataset for encoder-decoder based recognizer*</small>
```python
dataset_type = 'OCRDataset'
img_prefix = 'tests/data/ocr_toy_dataset/imgs'
@ -225,7 +230,7 @@ train = dict(
You can check the content of the annotation file in `tests/data/ocr_toy_dataset/label.txt`.
The combination of `HardDiskLoader` and `LineStrParser` will return a dict for each file by calling `__getitem__`: `{'filename': '1223731.jpg', 'text': 'GRAND'}`.
##### Optional Arguments:
**Optional Arguments:**
- `repeat`: The number of repeated lines in the annotation files. For example, if there are `10` lines in the annotation file, setting `repeat=10` will generate a corresponding annotation file with size `100`.
@ -254,7 +259,10 @@ train = dict(
test_mode=False)
```
#### Segmentation-Based Text Recognition Task
##### OCRSegDataset
<small>*Dataset for segmentation-based recognizer*</small>
```python
prefix = 'tests/data/ocr_char_ann_toy_dataset/'
train = dict(
@ -277,6 +285,11 @@ The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for
```
#### Text Detection Task
##### TextDetDataset
<small>*Dataset with annotation file in line-json txt format*</small>
```python
dataset_type = 'TextDetDataset'
img_prefix = 'tests/data/toy_dataset/imgs'
@ -302,7 +315,10 @@ The combination of `HardDiskLoader` and `LineJsonParser` will return a dict for
```
### COCO-like Dataset
##### IcdarDataset
<small>*Dataset with annotation file in coco-like json format*</small>
For text detection, you can also use an annotation file in a COCO format that is defined in [mmdet](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/coco.py):
```python
dataset_type = 'IcdarDataset'
@ -325,4 +341,40 @@ You can check the content of the annotation file in `tests/data/toy_dataset/inst
```shell
python tools/data_converter/ctw1500_converter.py ${src_root_path} -o ${out_path} --split-list training test
```
#### UniformConcatDataset
To use the `universal pipeline` for multiple datasets, we design `UniformConcatDataset`.
For example, apply `train_pipeline` for both `train1` and `train2`,
```python
data = dict(
...
train=dict(
type='UniformConcatDataset',
datasets=[train1, train2],
pipeline=train_pipeline))
```
Meanwhile, we have
- train_dataloader
- val_dataloader
- test_dataloader
to give specific settings. They will override the general settings in `data` dict.
For example,
```python
data = dict(
workers_per_gpu=2, # global setting
train_dataloader=dict(samples_per_gpu=8, drop_last=True), # train-specific setting
val_dataloader=dict(samples_per_gpu=8, workers_per_gpu=1), # val-specific setting
test_dataloader=dict(samples_per_gpu=8), # test-specific setting
...
```
`workers_per_gpu` is global setting and `train_dataloader` and `val_dataloader` will inherit the values.
`val_dataloader` override the value by `workers_per_gpu=1`.
To activate `batch inference` for `val` and `test`, please set `val_dataloader=dict(samples_per_gpu=8)` and `test_dataloader=dict(samples_per_gpu=8)` as above.
Or just set `samples_per_gpu=8` as global setting.
See [config](/configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py) for an example.

View File

@ -6,23 +6,40 @@ from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
def disable_text_recog_aug_test(cfg):
def disable_text_recog_aug_test(cfg, set_types=None):
"""Remove aug_test from test pipeline of text recognition.
Args:
cfg (mmcv.Config): Input config.
set_types (list[str]): Type of dataset source. Should be
None or sublist of ['test', 'val']
Returns:
cfg (mmcv.Config): Output config removing
`MultiRotateAugOCR` in test pipeline.
"""
if cfg.data.test.pipeline[1].type == 'MultiRotateAugOCR':
cfg.data.test.pipeline = [
cfg.data.test.pipeline[0], *cfg.data.test.pipeline[1].transforms
]
assert set_types is None or isinstance(set_types, list)
if set_types is None:
set_types = ['val', 'test']
for set_type in set_types:
if cfg.data[set_type].pipeline[1].type == 'MultiRotateAugOCR':
cfg.data[set_type].pipeline = [
cfg.data[set_type].pipeline[0],
*cfg.data[set_type].pipeline[1].transforms
]
assert_if_not_support_batch_mode(cfg, set_type)
return cfg
def assert_if_not_support_batch_mode(cfg, set_type='test'):
if cfg.data[set_type].pipeline[1].type == 'ResizeOCR':
if cfg.data[set_type].pipeline[1].max_width is None:
raise Exception('Batch mode is not supported '
'since the image width is not fixed, '
'in the case that keeping aspect ratio but '
'max_width is none when do resize.')
def model_inference(model, imgs, batch_mode=False):
"""Inference image(s) with the detector.
@ -51,13 +68,7 @@ def model_inference(model, imgs, batch_mode=False):
cfg = model.cfg
if batch_mode:
if cfg.data.test.pipeline[1].type == 'ResizeOCR':
if cfg.data.test.pipeline[1].max_width is None:
raise Exception('Free resize do not support batch mode '
'since the image width is not fixed, '
'for resize keeping aspect ratio and '
'max_width is not give.')
cfg = disable_text_recog_aug_test(cfg)
cfg = disable_text_recog_aug_test(cfg, set_types=['test'])
device = next(model.parameters()).device # model device

View File

@ -10,6 +10,7 @@ from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
from mmocr.apis.inference import disable_text_recog_aug_test
from mmocr.utils import get_root_logger
@ -24,30 +25,32 @@ def train_detector(model,
# prepare data loaders
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
if 'imgs_per_gpu' in cfg.data:
logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
'Please use "samples_per_gpu" instead')
if 'samples_per_gpu' in cfg.data:
logger.warning(
f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
f'={cfg.data.imgs_per_gpu} is used in this experiments')
else:
logger.warning(
'Automatically set "samples_per_gpu"="imgs_per_gpu"='
f'{cfg.data.imgs_per_gpu} in this experiments')
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
data_loaders = [
build_dataloader(
ds,
cfg.data.samples_per_gpu,
cfg.data.workers_per_gpu,
# cfg.gpus will be ignored if distributed
len(cfg.gpu_ids),
# step 1: give default values and override (if exist) from cfg.data
loader_cfg = {
**dict(
seed=cfg.get('seed'),
drop_last=False,
dist=distributed,
seed=cfg.seed) for ds in dataset
]
num_gpus=len(cfg.gpu_ids)),
**({} if torch.__version__ != 'parrots' else dict(
prefetch_num=2,
pin_memory=False,
)),
**dict((k, cfg.data[k]) for k in [
'samples_per_gpu',
'workers_per_gpu',
'shuffle',
'seed',
'drop_last',
'prefetch_num',
'pin_memory',
] if k in cfg.data)
}
# step 2: cfg.data.train_dataloader has highest priority
train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))
data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
# put model on gpus
if distributed:
@ -110,19 +113,28 @@ def train_detector(model,
# register eval hooks
if validate:
# Support batch_size > 1 in validation
val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
val_samples_per_gpu = (cfg.data.get('val_dataloader', {})).get(
'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
if val_samples_per_gpu > 1:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
cfg.data.val.pipeline = replace_ImageToTensor(
cfg.data.val.pipeline)
# Support batch_size > 1 in test for text recognition
# by disable MultiRotateAugOCR since it is useless for most case
cfg = disable_text_recog_aug_test(cfg)
if cfg.data.val.get('pipeline', None) is not None:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
cfg.data.val.pipeline = replace_ImageToTensor(
cfg.data.val.pipeline)
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
val_dataloader = build_dataloader(
val_dataset,
samples_per_gpu=val_samples_per_gpu,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
val_loader_cfg = {
**loader_cfg,
**dict(shuffle=False, drop_last=False),
**cfg.data.get('val_dataloader', {}),
**dict(samples_per_gpu=val_samples_per_gpu)
}
val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
eval_cfg = cfg.get('evaluation', {})
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
eval_hook = DistEvalHook if distributed else EvalHook

View File

@ -9,6 +9,7 @@ from .ocr_dataset import OCRDataset
from .ocr_seg_dataset import OCRSegDataset
from .pipelines import CustomFormatBundle, DBNetTargets, FCENetTargets
from .text_det_dataset import TextDetDataset
from .uniform_concat_dataset import UniformConcatDataset
from .utils import * # NOQA
@ -16,7 +17,7 @@ __all__ = [
'DATASETS', 'IcdarDataset', 'build_dataloader', 'build_dataset',
'BaseDataset', 'OCRDataset', 'TextDetDataset', 'CustomFormatBundle',
'DBNetTargets', 'OCRSegDataset', 'KIEDataset', 'FCENetTargets',
'NerDataset'
'NerDataset', 'UniformConcatDataset'
]
__all__ += utils.__all__

View File

@ -0,0 +1,27 @@
import copy
from mmdet.datasets import DATASETS, ConcatDataset, build_dataset
@DATASETS.register_module()
class UniformConcatDataset(ConcatDataset):
"""A wrapper of concatenated dataset.
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
concat the group flag for image aspect ratio.
Args:
datasets (list[:obj:`Dataset`]): A list of datasets.
separate_eval (bool): Whether to evaluate the results
separately if it is used as validation dataset.
Defaults to True.
"""
def __init__(self, datasets, separate_eval=True, pipeline=None, **kwargs):
from_cfg = all(isinstance(x, dict) for x in datasets)
if pipeline is not None:
assert from_cfg, 'datasets should be config dicts'
for dataset in datasets:
dataset['pipeline'] = copy.deepcopy(pipeline)
datasets = [build_dataset(c, kwargs) for c in datasets]
super().__init__(datasets, separate_eval)

View File

@ -114,19 +114,19 @@ def test_model_batch_inference_raises_exception_error_free_resize_recog(
with pytest.raises(
Exception,
match='Free resize do not support batch mode '
match='Batch mode is not supported '
'since the image width is not fixed, '
'for resize keeping aspect ratio and '
'max_width is not give.'):
'in the case that keeping aspect ratio but '
'max_width is none when do resize.'):
sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
model_inference(
model, [sample_img_path, sample_img_path], batch_mode=True)
with pytest.raises(
Exception,
match='Free resize do not support batch mode '
match='Batch mode is not supported '
'since the image width is not fixed, '
'for resize keeping aspect ratio and '
'max_width is not give.'):
'in the case that keeping aspect ratio but '
'max_width is none when do resize.'):
img = imread(sample_img_path)
model_inference(model, [img, img], batch_mode=True)

View File

@ -13,6 +13,7 @@ from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
from mmdet.apis import multi_gpu_test, single_gpu_test
from mmdet.datasets import replace_ImageToTensor
from mmocr.apis.inference import disable_text_recog_aug_test
from mmocr.datasets import build_dataloader, build_dataset
from mmocr.models import build_detector
@ -140,12 +141,16 @@ def main():
# in case the test dataset is concatenated
samples_per_gpu = 1
if isinstance(cfg.data.test, dict):
cfg.data.test.test_mode = True
samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
samples_per_gpu = (cfg.data.get('test_dataloader', {})).get(
'samples_per_gpu', cfg.data.get('samples_per_gpu', 1))
if samples_per_gpu > 1:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
cfg.data.test.pipeline = replace_ImageToTensor(
cfg.data.test.pipeline)
# Support batch_size > 1 in test for text recognition
# by disable MultiRotateAugOCR since it is useless for most case
cfg = disable_text_recog_aug_test(cfg)
if cfg.data.test.get('pipeline', None) is not None:
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
cfg.data.test.pipeline = replace_ImageToTensor(
cfg.data.test.pipeline)
elif isinstance(cfg.data.test, list):
for ds_cfg in cfg.data.test:
ds_cfg.test_mode = True
@ -163,13 +168,29 @@ def main():
init_dist(args.launcher, **cfg.dist_params)
# build the dataloader
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
dataset,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
dataset = build_dataset(cfg.data.test, dict(test_mode=True))
# step 1: give default values and override (if exist) from cfg.data
loader_cfg = {
**dict(seed=cfg.get('seed'), drop_last=False, dist=distributed),
**({} if torch.__version__ != 'parrots' else dict(
prefetch_num=2,
pin_memory=False,
)),
**dict((k, cfg.data[k]) for k in [
'workers_per_gpu',
'seed',
'prefetch_num',
'pin_memory',
] if k in cfg.data)
}
test_loader_cfg = {
**loader_cfg,
**dict(shuffle=False, drop_last=False),
**cfg.data.get('test_dataloader', {}),
**dict(samples_per_gpu=samples_per_gpu)
}
data_loader = build_dataloader(dataset, **test_loader_cfg)
# build the model and load checkpoint
cfg.model.train_cfg = None