[Feature] Add LoadImageFromNdArray pipeline (#1810)

* add load image from ndarray pipeline

* fix import
pull/1794/head
谢昕辰 2022-07-22 19:40:00 +08:00 committed by GitHub
parent 4de57b49c5
commit ba4d1d62aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 5 deletions

View File

@ -15,7 +15,8 @@ from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .transforms import (CLAHE, AdjustGamma, LoadAnnotations, PackSegInputs,
from .transforms import (CLAHE, AdjustGamma, LoadAnnotations,
LoadImageFromNDArray, PackSegInputs,
PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomMosaic, RandomRotate, Rerange, ResizeToMultiple,
RGB2Gray, SegRescale)
@ -29,5 +30,6 @@ __all__ = [
'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple'
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray'
]

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .formatting import PackSegInputs
from .loading import LoadAnnotations
from .loading import LoadAnnotations, LoadImageFromNDArray
from .transforms import (CLAHE, AdjustGamma, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
ResizeToMultiple, RGB2Gray, SegRescale)
@ -8,5 +8,6 @@ from .transforms import (CLAHE, AdjustGamma, PhotoMetricDistortion, RandomCrop,
__all__ = [
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple'
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray'
]

View File

@ -4,6 +4,7 @@ import warnings
import mmcv
import numpy as np
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
from mmcv.transforms import LoadImageFromFile
from mmseg.registry import TRANSFORMS
@ -120,3 +121,50 @@ class LoadAnnotations(MMCV_LoadAnnotations):
repr_str += f"imdecode_backend='{self.imdecode_backend}')"
repr_str += f'file_client_args={self.file_client_args})'
return repr_str
@TRANSFORMS.register_module()
class LoadImageFromNDArray(LoadImageFromFile):
    """Fill in image meta information for an image that is already loaded.

    Works like :obj:`LoadImageFromFile`, except that nothing is read from a
    file: the image is taken as-is from ``results['img']``, where it is
    expected to be an :obj:`np.ndarray` (for example a frame read from a
    webcam).

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_path
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to cast the image to a float32 numpy
            array. When False the array is stored back without any cast.
            Defaults to False.
    """

    def transform(self, results: dict) -> dict:
        """Attach meta information for the image held in ``results``.

        Args:
            results (dict): Result dict carrying the already-read image
                under ``results['img']``.

        Returns:
            dict: The same dict, with ``img``, ``img_path``, ``img_shape``
            and ``ori_shape`` set.
        """
        image = results['img']
        if self.to_float32:
            image = image.astype(np.float32)

        # There is no source file behind this image, so the path is None.
        results['img_path'] = None
        results['img'] = image

        # Only the first two dimensions of the array are recorded, and the
        # current and original shapes start out identical.
        leading_dims = image.shape[:2]
        results['img_shape'] = leading_dims
        results['ori_shape'] = leading_dims
        return results

View File

@ -7,7 +7,7 @@ import mmcv
import numpy as np
from mmcv.transforms import LoadImageFromFile
from mmseg.datasets.transforms import LoadAnnotations
from mmseg.datasets.transforms import LoadAnnotations, LoadImageFromNDArray
class TestLoading(object):
@ -161,3 +161,27 @@ class TestLoading(object):
np.testing.assert_array_equal(gt_array, test_gt)
tmp_dir.cleanup()
def test_load_image_from_ndarray(self):
    """Check ``LoadImageFromNDArray`` output keys, dtype handling and repr."""
    blank_image = np.zeros((256, 256, 3), dtype=np.uint8)
    results = {'img': blank_image}

    # Default construction: the image is passed through and shapes recorded.
    loader = LoadImageFromNDArray()
    results = loader(results)
    assert results['img'].shape == (256, 256, 3)
    assert results['img'].dtype == np.uint8
    assert results['img_shape'] == (256, 256)
    assert results['ori_shape'] == (256, 256)

    # to_float32
    float_loader = LoadImageFromNDArray(to_float32=True)
    results = float_loader(copy.deepcopy(results))
    assert results['img'].dtype == np.float32

    # test repr
    expected_repr = ('LoadImageFromNDArray('
                     'ignore_empty=False, '
                     'to_float32=False, '
                     "color_type='color', "
                     "imdecode_backend='cv2', "
                     "file_client_args={'backend': 'disk'})")
    default_loader = LoadImageFromNDArray()
    assert repr(default_loader) == expected_repr