mirror of https://github.com/alibaba/EasyCV.git
61 lines
2.5 KiB
Python
61 lines
2.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
import warnings
|
|
|
|
|
|
def replace_ImageToTensor(pipelines):
    """Replace every ImageToTensor transform in a data pipeline with
    DefaultFormatBundle, which is normally useful in batch inference.

    Entries of type ``MMMultiScaleFlipAug`` are searched recursively through
    their ``transforms`` list. The input is deep-copied first, so the
    caller's ``pipelines`` (including nested lists) is never modified.

    Each replaced entry becomes exactly ``{'type': 'DefaultFormatBundle'}``.
    Any other arguments of the original ``ImageToTensor`` config (e.g.
    ``keys``) are dropped. A ``UserWarning`` is emitted for each replacement.

    Args:
        pipelines (list[dict]): Data pipeline configs. Every entry must
            have a ``'type'`` key.

    Returns:
        list: The new pipeline list with all ImageToTensor replaced by
            DefaultFormatBundle.

    Raises:
        AssertionError: If an ``MMMultiScaleFlipAug`` entry has no
            ``transforms`` key (only while assertions are enabled).

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MMMultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='ImageToTensor', keys=['img']),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MMMultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='DefaultFormatBundle'),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> assert expected_pipelines == replace_ImageToTensor(pipelines)
    """
    # Work on a deep copy: entries (and nested ``transforms`` lists) are
    # reassigned in place below, and the caller's config must stay intact.
    pipelines = copy.deepcopy(pipelines)
    for i, pipeline in enumerate(pipelines):
        if pipeline['type'] == 'MMMultiScaleFlipAug':
            # The wrapper carries its own sub-pipeline; scan it recursively.
            assert 'transforms' in pipeline
            pipeline['transforms'] = replace_ImageToTensor(
                pipeline['transforms'])
        elif pipeline['type'] == 'ImageToTensor':
            warnings.warn(
                '"ImageToTensor" pipeline is replaced by '
                '"DefaultFormatBundle" for batch inference. It is '
                'recommended to manually replace it in the test '
                'data pipeline in your config file.', UserWarning)
            pipelines[i] = {'type': 'DefaultFormatBundle'}
    return pipelines
|