[Feature] Add BioMedical data loading (#2176)
* [WIP] Add BioMedical data loading * add depends nibabel * fix bug * fix ut * fix * add test data * xyz2zyx zyx2xyz * format * remove ignore empty * remove ignore empty * remove with seg in LoadBiomedicalAnnotation * float32 * docstring * toxyz * docstringpull/2220/head
parent
2ea4034014
commit
20c7dc689c
|
@ -16,10 +16,11 @@ from .pascal_context import PascalContextDataset, PascalContextDataset59
|
|||
from .potsdam import PotsdamDataset
|
||||
from .stare import STAREDataset
|
||||
from .transforms import (CLAHE, AdjustGamma, LoadAnnotations,
|
||||
LoadImageFromNDArray, PackSegInputs,
|
||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||
RandomMosaic, RandomRotate, Rerange, ResizeToMultiple,
|
||||
RGB2Gray, SegRescale)
|
||||
LoadBiomedicalAnnotation, LoadBiomedicalData,
|
||||
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
|
||||
PackSegInputs, PhotoMetricDistortion, RandomCrop,
|
||||
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
|
||||
ResizeToMultiple, RGB2Gray, SegRescale)
|
||||
from .voc import PascalVOCDataset
|
||||
|
||||
__all__ = [
|
||||
|
@ -31,5 +32,6 @@ __all__ = [
|
|||
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
|
||||
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
|
||||
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||
'LoadImageFromNDArray'
|
||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData'
|
||||
]
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .formatting import PackSegInputs
|
||||
from .loading import LoadAnnotations, LoadImageFromNDArray
|
||||
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
||||
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
||||
LoadImageFromNDArray)
|
||||
from .transforms import (CLAHE, AdjustGamma, PhotoMetricDistortion, RandomCrop,
|
||||
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
|
||||
ResizeToMultiple, RGB2Gray, SegRescale)
|
||||
|
@ -9,5 +11,6 @@ __all__ = [
|
|||
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
|
||||
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
|
||||
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||
'LoadImageFromNDArray'
|
||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData'
|
||||
]
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from typing import Dict
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import numpy as np
|
||||
from mmcv.transforms import BaseTransform
|
||||
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
|
||||
from mmcv.transforms import LoadImageFromFile
|
||||
|
||||
from mmseg.registry import TRANSFORMS
|
||||
from mmseg.utils import datafrombytes
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
|
@ -168,3 +172,273 @@ class LoadImageFromNDArray(LoadImageFromFile):
|
|||
results['img_shape'] = img.shape[:2]
|
||||
results['ori_shape'] = img.shape[:2]
|
||||
return results
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadBiomedicalImageFromFile(BaseTransform):
|
||||
"""Load an biomedical mage from file.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_path
|
||||
|
||||
Added Keys:
|
||||
|
||||
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
|
||||
N is the number of modalities, and data type is float32
|
||||
if set to_float32 = True, or float64 if decode_backend is 'nifti' and
|
||||
to_float32 is False.
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
decode_backend (str): The data decoding backend type. Options are
|
||||
'numpy'and 'nifti', and there is a convention that when backend is
|
||||
'nifti' the axis of data loaded is XYZ, and when backend is
|
||||
'numpy', the the axis is ZYX. The data will be transposed if the
|
||||
backend is 'nifti'. Defaults to 'nifti'.
|
||||
to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z.
|
||||
Defaults to False.
|
||||
to_float32 (bool): Whether to convert the loaded image to a float32
|
||||
numpy array. If set to False, the loaded image is an float64 array.
|
||||
Defaults to True.
|
||||
file_client_args (dict): Arguments to instantiate a FileClient.
|
||||
See :class:`mmengine.fileio.FileClient` for details.
|
||||
Defaults to ``dict(backend='disk')``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decode_backend: str = 'nifti',
|
||||
to_xyz: bool = False,
|
||||
to_float32: bool = True,
|
||||
file_client_args: dict = dict(backend='disk')
|
||||
) -> None:
|
||||
self.decode_backend = decode_backend
|
||||
self.to_xyz = to_xyz
|
||||
self.to_float32 = to_float32
|
||||
self.file_client_args = file_client_args.copy()
|
||||
self.file_client = mmengine.FileClient(**self.file_client_args)
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
|
||||
filename = results['img_path']
|
||||
|
||||
data_bytes = self.file_client.get(filename)
|
||||
img = datafrombytes(data_bytes, backend=self.decode_backend)
|
||||
|
||||
if self.to_float32:
|
||||
img = img.astype(np.float32)
|
||||
|
||||
if len(img.shape) == 3:
|
||||
img = img[None, ...]
|
||||
|
||||
if self.decode_backend == 'nifti':
|
||||
img = img.transpose(0, 3, 2, 1)
|
||||
|
||||
if self.to_xyz:
|
||||
img = img.transpose(0, 3, 2, 1)
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[1:]
|
||||
results['ori_shape'] = img.shape[1:]
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f"decode_backend='{self.decode_backend}', "
|
||||
f'to_xyz={self.to_xyz}, '
|
||||
f'to_float32={self.to_float32}, '
|
||||
f'file_client_args={self.file_client_args})')
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadBiomedicalAnnotation(BaseTransform):
|
||||
"""Load ``seg_map`` annotation provided by biomedical dataset.
|
||||
|
||||
The annotation format is as the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'gt_seg_map': np.ndarray (X, Y, Z) or (Z, Y, X)
|
||||
}
|
||||
|
||||
Required Keys:
|
||||
|
||||
- seg_map_path
|
||||
|
||||
Added Keys:
|
||||
|
||||
- gt_seg_map (np.ndarray): Biomedical seg map with shape (Z, Y, X) by
|
||||
default, and data type is float32 if set to_float32 = True, or
|
||||
float64 if decode_backend is 'nifti' and to_float32 is False.
|
||||
|
||||
Args:
|
||||
decode_backend (str): The data decoding backend type. Options are
|
||||
'numpy'and 'nifti', and there is a convention that when backend is
|
||||
'nifti' the axis of data loaded is XYZ, and when backend is
|
||||
'numpy', the the axis is ZYX. The data will be transposed if the
|
||||
backend is 'nifti'. Defaults to 'nifti'.
|
||||
to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z.
|
||||
Defaults to False.
|
||||
to_float32 (bool): Whether to convert the loaded seg map to a float32
|
||||
numpy array. If set to False, the loaded image is an float64 array.
|
||||
Defaults to True.
|
||||
file_client_args (dict): Arguments to instantiate a FileClient.
|
||||
See :class:`mmengine.fileio.FileClient` for details.
|
||||
Defaults to ``dict(backend='disk')``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decode_backend: str = 'nifti',
|
||||
to_xyz: bool = False,
|
||||
to_float32: bool = True,
|
||||
file_client_args: dict = dict(backend='disk')
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.decode_backend = decode_backend
|
||||
self.to_xyz = to_xyz
|
||||
self.to_float32 = to_float32
|
||||
self.file_client_args = file_client_args.copy()
|
||||
self.file_client = mmengine.FileClient(**self.file_client_args)
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
data_bytes = self.file_client.get(results['seg_map_path'])
|
||||
gt_seg_map = datafrombytes(data_bytes, backend=self.decode_backend)
|
||||
|
||||
if self.to_float32:
|
||||
gt_seg_map = gt_seg_map.astype(np.float32)
|
||||
|
||||
if self.decode_backend == 'nifti':
|
||||
gt_seg_map = gt_seg_map.transpose(2, 1, 0)
|
||||
|
||||
if self.to_xyz:
|
||||
gt_seg_map = gt_seg_map.transpose(2, 1, 0)
|
||||
|
||||
results['gt_seg_map'] = gt_seg_map
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f"decode_backend='{self.decode_backend}', "
|
||||
f'to_xyz={self.to_xyz}, '
|
||||
f'to_float32={self.to_float32}, '
|
||||
f'file_client_args={self.file_client_args})')
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadBiomedicalData(BaseTransform):
|
||||
"""Load an biomedical image and annotation from file.
|
||||
|
||||
The loading data format is as the following:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'img': np.ndarray data[:-1, X, Y, Z]
|
||||
'seg_map': np.ndarray data[-1, X, Y, Z]
|
||||
}
|
||||
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img_path
|
||||
|
||||
Added Keys:
|
||||
|
||||
- img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
|
||||
N is the number of modalities.
|
||||
- gt_seg_map (np.ndarray, optional): Biomedical seg map with shape
|
||||
(Z, Y, X) by default.
|
||||
- img_shape
|
||||
- ori_shape
|
||||
|
||||
Args:
|
||||
with_seg (bool): Whether to parse and load the semantic segmentation
|
||||
annotation. Defaults to False.
|
||||
decode_backend (str): The data decoding backend type. Options are
|
||||
'numpy'and 'nifti', and there is a convention that when backend is
|
||||
'nifti' the axis of data loaded is XYZ, and when backend is
|
||||
'numpy', the the axis is ZYX. The data will be transposed if the
|
||||
backend is 'nifti'. Defaults to 'nifti'.
|
||||
to_xyz (bool): Whether transpose data from Z, Y, X to X, Y, Z.
|
||||
Defaults to False.
|
||||
file_client_args (dict): Arguments to instantiate a FileClient.
|
||||
See :class:`mmengine.fileio.FileClient` for details.
|
||||
Defaults to ``dict(backend='disk')``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
with_seg=False,
|
||||
decode_backend: str = 'numpy',
|
||||
to_xyz: bool = False,
|
||||
file_client_args: dict = dict(backend='disk')
|
||||
) -> None:
|
||||
self.with_seg = with_seg
|
||||
self.decode_backend = decode_backend
|
||||
self.to_xyz = to_xyz
|
||||
self.file_client_args = file_client_args.copy()
|
||||
self.file_client = mmengine.FileClient(**self.file_client_args)
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
"""Functions to load image.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
|
||||
|
||||
Returns:
|
||||
dict: The dict contains loaded image and meta information.
|
||||
"""
|
||||
data_bytes = self.file_client.get(results['img_path'])
|
||||
data = datafrombytes(data_bytes, backend=self.decode_backend)
|
||||
# img is 4D data (N, X, Y, Z), N is the number of protocol
|
||||
img = data[:-1, :]
|
||||
|
||||
if self.decode_backend == 'nifti':
|
||||
img = img.transpose(0, 3, 2, 1)
|
||||
|
||||
if self.to_xyz:
|
||||
img = img.transpose(0, 3, 2, 1)
|
||||
|
||||
results['img'] = img
|
||||
results['img_shape'] = img.shape[1:]
|
||||
results['ori_shape'] = img.shape[1:]
|
||||
|
||||
if self.with_seg:
|
||||
gt_seg_map = data[-1, :]
|
||||
if self.decode_backend == 'nifti':
|
||||
gt_seg_map = gt_seg_map.transpose(2, 1, 0)
|
||||
|
||||
if self.to_xyz:
|
||||
gt_seg_map = gt_seg_map.transpose(2, 1, 0)
|
||||
results['gt_seg_map'] = gt_seg_map
|
||||
return results
|
||||
|
||||
def __repr__(self) -> str:
|
||||
repr_str = (f'{self.__class__.__name__}('
|
||||
f'with_seg={self.with_seg}, '
|
||||
f"decode_backend='{self.decode_backend}', "
|
||||
f'to_xyz={self.to_xyz}, '
|
||||
f'file_client_args={self.file_client_args})')
|
||||
return repr_str
|
||||
|
|
|
@ -10,6 +10,7 @@ from .class_names import (ade_classes, ade_palette, cityscapes_classes,
|
|||
voc_palette)
|
||||
# yapf: enable
|
||||
from .collect_env import collect_env
|
||||
from .io import datafrombytes
|
||||
from .misc import add_prefix, stack_batch
|
||||
from .set_env import register_all_modules
|
||||
from .typing import (ConfigType, ForwardResults, MultiConfig, OptConfigType,
|
||||
|
@ -25,5 +26,6 @@ __all__ = [
|
|||
'vaihingen_classes', 'isaid_classes', 'stare_classes',
|
||||
'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette',
|
||||
'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette',
|
||||
'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette'
|
||||
'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette',
|
||||
'datafrombytes'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import gzip
|
||||
import io
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray:
|
||||
"""Data decoding from bytes.
|
||||
|
||||
Args:
|
||||
content (bytes): The data bytes got from files or other streams.
|
||||
backend (str): The data decoding backend type. Options are 'numpy',
|
||||
'nifti' and 'pickle'. Defaults to 'numpy'.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Loaded data array.
|
||||
"""
|
||||
if backend == 'pickle':
|
||||
data = pickle.loads(content)
|
||||
else:
|
||||
with io.BytesIO(content) as f:
|
||||
if backend == 'nifti':
|
||||
f = gzip.open(f)
|
||||
try:
|
||||
from nibabel import FileHolder, Nifti1Image
|
||||
except ImportError:
|
||||
print('nifti files io depends on nibabel, please run'
|
||||
'`pip install nibabel` to install it')
|
||||
fh = FileHolder(fileobj=f)
|
||||
data = Nifti1Image.from_file_map({'header': fh, 'image': fh})
|
||||
data = Nifti1Image.from_bytes(data.to_bytes()).get_fdata()
|
||||
elif backend == 'numpy':
|
||||
data = np.load(f)
|
||||
else:
|
||||
raise ValueError
|
||||
return data
|
|
@ -1 +1,2 @@
|
|||
cityscapesscripts
|
||||
nibabel
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -7,7 +7,11 @@ import mmcv
|
|||
import numpy as np
|
||||
from mmcv.transforms import LoadImageFromFile
|
||||
|
||||
from mmseg.datasets.transforms import LoadAnnotations, LoadImageFromNDArray
|
||||
from mmseg.datasets.transforms import (LoadAnnotations,
|
||||
LoadBiomedicalAnnotation,
|
||||
LoadBiomedicalData,
|
||||
LoadBiomedicalImageFromFile,
|
||||
LoadImageFromNDArray)
|
||||
|
||||
|
||||
class TestLoading:
|
||||
|
@ -185,3 +189,53 @@ class TestLoading:
|
|||
"color_type='color', "
|
||||
"imdecode_backend='cv2', "
|
||||
"file_client_args={'backend': 'disk'})")
|
||||
|
||||
def test_load_biomedical_img(self):
|
||||
results = dict(
|
||||
img_path=osp.join(self.data_prefix, 'biomedical.nii.gz'))
|
||||
transform = LoadBiomedicalImageFromFile()
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert results['img_path'] == osp.join(self.data_prefix,
|
||||
'biomedical.nii.gz')
|
||||
assert len(results['img'].shape) == 4
|
||||
assert results['img'].dtype == np.float32
|
||||
assert results['ori_shape'] == results['img'].shape[1:]
|
||||
assert repr(transform) == ('LoadBiomedicalImageFromFile('
|
||||
"decode_backend='nifti', "
|
||||
'to_xyz=False, '
|
||||
'to_float32=True, '
|
||||
"file_client_args={'backend': 'disk'})")
|
||||
|
||||
def test_load_biomedical_annotation(self):
|
||||
results = dict(
|
||||
seg_map_path=osp.join(self.data_prefix, 'biomedical_ann.nii.gz'))
|
||||
transform = LoadBiomedicalAnnotation()
|
||||
results = transform(copy.deepcopy(results))
|
||||
assert len(results['gt_seg_map'].shape) == 3
|
||||
assert results['gt_seg_map'].dtype == np.float32
|
||||
|
||||
def test_load_biomedical_data(self):
|
||||
input_results = dict(
|
||||
img_path=osp.join(self.data_prefix, 'biomedical.npy'))
|
||||
transform = LoadBiomedicalData(with_seg=True)
|
||||
results = transform(copy.deepcopy(input_results))
|
||||
assert results['img_path'] == osp.join(self.data_prefix,
|
||||
'biomedical.npy')
|
||||
assert results['img'][0].shape == results['gt_seg_map'].shape
|
||||
assert results['img'].dtype == np.float32
|
||||
assert results['ori_shape'] == results['img'].shape[1:]
|
||||
assert repr(transform) == ('LoadBiomedicalData('
|
||||
'with_seg=True, '
|
||||
"decode_backend='numpy', "
|
||||
'to_xyz=False, '
|
||||
"file_client_args={'backend': 'disk'})")
|
||||
|
||||
transform = LoadBiomedicalData(with_seg=False)
|
||||
results = transform(copy.deepcopy(input_results))
|
||||
assert len(results['img'].shape) == 4
|
||||
assert results.get('gt_seg_map') is None
|
||||
assert repr(transform) == ('LoadBiomedicalData('
|
||||
'with_seg=False, '
|
||||
"decode_backend='numpy', "
|
||||
'to_xyz=False, '
|
||||
"file_client_args={'backend': 'disk'})")
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mmengine import FileClient
|
||||
|
||||
from mmseg.utils import datafrombytes
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
['backend', 'suffix'],
|
||||
[['nifti', '.nii.gz'], ['numpy', '.npy'], ['pickle', '.pkl']])
|
||||
def test_datafrombytes(backend, suffix):
|
||||
|
||||
file_client = FileClient('disk')
|
||||
file_path = osp.join(osp.dirname(__file__), '../data/biomedical' + suffix)
|
||||
bytes = file_client.get(file_path)
|
||||
data = datafrombytes(bytes, backend)
|
||||
|
||||
if backend == 'pickle':
|
||||
# test pickle loading
|
||||
assert isinstance(data, dict)
|
||||
else:
|
||||
assert isinstance(data, np.ndarray)
|
||||
if backend == 'nifti':
|
||||
# test nifti file loading
|
||||
assert len(data.shape) == 3
|
||||
else:
|
||||
# test npy file loading
|
||||
# testing data biomedical.npy includes data and label
|
||||
assert len(data.shape) == 4
|
||||
assert data.shape[0] == 2
|
Loading…
Reference in New Issue