Mirror of https://github.com/open-mmlab/mmsegmentation.git (synced 2025-06-03 22:03:48 +08:00)
[Feature] Add BioMedical data loading (#2176)
* [WIP] Add BioMedical data loading
* add depends nibabel
* fix bug
* fix ut
* fix
* add test data
* xyz2zyx zyx2xyz
* format
* remove ignore empty
* remove ignore empty
* remove with seg in LoadBiomedicalAnnotation
* float32
* docstring
* toxyz
* docstring
This commit is contained in:
parent 2ea4034014
commit 20c7dc689c
mmseg/datasets/__init__.py

@@ -16,10 +16,11 @@ from .pascal_context import PascalContextDataset, PascalContextDataset59
 from .potsdam import PotsdamDataset
 from .stare import STAREDataset
 from .transforms import (CLAHE, AdjustGamma, LoadAnnotations,
-                         LoadImageFromNDArray, PackSegInputs,
-                         PhotoMetricDistortion, RandomCrop, RandomCutOut,
-                         RandomMosaic, RandomRotate, Rerange, ResizeToMultiple,
-                         RGB2Gray, SegRescale)
+                         LoadBiomedicalAnnotation, LoadBiomedicalData,
+                         LoadBiomedicalImageFromFile, LoadImageFromNDArray,
+                         PackSegInputs, PhotoMetricDistortion, RandomCrop,
+                         RandomCutOut, RandomMosaic, RandomRotate, Rerange,
+                         ResizeToMultiple, RGB2Gray, SegRescale)
 from .voc import PascalVOCDataset

 __all__ = [
@@ -31,5 +32,6 @@ __all__ = [
     'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
     'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
     'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
-    'LoadImageFromNDArray'
+    'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
+    'LoadBiomedicalAnnotation', 'LoadBiomedicalData'
 ]
mmseg/datasets/transforms/__init__.py

@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .formatting import PackSegInputs
-from .loading import LoadAnnotations, LoadImageFromNDArray
+from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
+                      LoadBiomedicalData, LoadBiomedicalImageFromFile,
+                      LoadImageFromNDArray)
 from .transforms import (CLAHE, AdjustGamma, PhotoMetricDistortion, RandomCrop,
                          RandomCutOut, RandomMosaic, RandomRotate, Rerange,
                          ResizeToMultiple, RGB2Gray, SegRescale)
@@ -9,5 +11,6 @@ __all__ = [
     'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
     'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
     'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
-    'LoadImageFromNDArray'
+    'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
+    'LoadBiomedicalAnnotation', 'LoadBiomedicalData'
 ]
mmseg/datasets/transforms/loading.py

@@ -1,12 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+from typing import Dict

 import mmcv
+import mmengine
 import numpy as np
+from mmcv.transforms import BaseTransform
 from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations
 from mmcv.transforms import LoadImageFromFile

 from mmseg.registry import TRANSFORMS
+from mmseg.utils import datafrombytes


 @TRANSFORMS.register_module()
@@ -168,3 +172,273 @@ class LoadImageFromNDArray(LoadImageFromFile):
         results['img_shape'] = img.shape[:2]
         results['ori_shape'] = img.shape[:2]
         return results
+
+
+@TRANSFORMS.register_module()
+class LoadBiomedicalImageFromFile(BaseTransform):
+    """Load a biomedical image from file.
+
+    Required Keys:
+
+    - img_path
+
+    Added Keys:
+
+    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
+        N is the number of modalities, and data type is float32
+        if set to_float32 = True, or float64 if decode_backend is 'nifti' and
+        to_float32 is False.
+    - img_shape
+    - ori_shape
+
+    Args:
+        decode_backend (str): The data decoding backend type. Options are
+            'numpy' and 'nifti', and there is a convention that when the
+            backend is 'nifti' the axis order of the loaded data is XYZ, and
+            when the backend is 'numpy' it is ZYX. The data will be
+            transposed if the backend is 'nifti'. Defaults to 'nifti'.
+        to_xyz (bool): Whether to transpose data from Z, Y, X to X, Y, Z.
+            Defaults to False.
+        to_float32 (bool): Whether to convert the loaded image to a float32
+            numpy array. If set to False, the loaded image is a float64
+            array. Defaults to True.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmengine.fileio.FileClient` for details.
+            Defaults to ``dict(backend='disk')``.
+    """
+
+    def __init__(
+        self,
+        decode_backend: str = 'nifti',
+        to_xyz: bool = False,
+        to_float32: bool = True,
+        file_client_args: dict = dict(backend='disk')
+    ) -> None:
+        self.decode_backend = decode_backend
+        self.to_xyz = to_xyz
+        self.to_float32 = to_float32
+        self.file_client_args = file_client_args.copy()
+        self.file_client = mmengine.FileClient(**self.file_client_args)
+
+    def transform(self, results: Dict) -> Dict:
+        """Functions to load a biomedical image.
+
+        Args:
+            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
+
+        Returns:
+            dict: The dict contains loaded image and meta information.
+        """
+
+        filename = results['img_path']
+
+        data_bytes = self.file_client.get(filename)
+        img = datafrombytes(data_bytes, backend=self.decode_backend)
+
+        if self.to_float32:
+            img = img.astype(np.float32)
+
+        if len(img.shape) == 3:
+            img = img[None, ...]
+
+        if self.decode_backend == 'nifti':
+            img = img.transpose(0, 3, 2, 1)
+
+        if self.to_xyz:
+            img = img.transpose(0, 3, 2, 1)
+
+        results['img'] = img
+        results['img_shape'] = img.shape[1:]
+        results['ori_shape'] = img.shape[1:]
+        return results
+
+    def __repr__(self):
+        repr_str = (f'{self.__class__.__name__}('
+                    f"decode_backend='{self.decode_backend}', "
+                    f'to_xyz={self.to_xyz}, '
+                    f'to_float32={self.to_float32}, '
+                    f'file_client_args={self.file_client_args})')
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class LoadBiomedicalAnnotation(BaseTransform):
+    """Load ``seg_map`` annotation provided by biomedical dataset.
+
+    The annotation format is as the following:
+
+    .. code-block:: python
+
+        {
+            'gt_seg_map': np.ndarray (X, Y, Z) or (Z, Y, X)
+        }
+
+    Required Keys:
+
+    - seg_map_path
+
+    Added Keys:
+
+    - gt_seg_map (np.ndarray): Biomedical seg map with shape (Z, Y, X) by
+        default, and data type is float32 if set to_float32 = True, or
+        float64 if decode_backend is 'nifti' and to_float32 is False.
+
+    Args:
+        decode_backend (str): The data decoding backend type. Options are
+            'numpy' and 'nifti', and there is a convention that when the
+            backend is 'nifti' the axis order of the loaded data is XYZ, and
+            when the backend is 'numpy' it is ZYX. The data will be
+            transposed if the backend is 'nifti'. Defaults to 'nifti'.
+        to_xyz (bool): Whether to transpose data from Z, Y, X to X, Y, Z.
+            Defaults to False.
+        to_float32 (bool): Whether to convert the loaded seg map to a float32
+            numpy array. If set to False, the loaded seg map is a float64
+            array. Defaults to True.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmengine.fileio.FileClient` for details.
+            Defaults to ``dict(backend='disk')``.
+    """
+
+    def __init__(
+        self,
+        decode_backend: str = 'nifti',
+        to_xyz: bool = False,
+        to_float32: bool = True,
+        file_client_args: dict = dict(backend='disk')
+    ) -> None:
+        super().__init__()
+        self.decode_backend = decode_backend
+        self.to_xyz = to_xyz
+        self.to_float32 = to_float32
+        self.file_client_args = file_client_args.copy()
+        self.file_client = mmengine.FileClient(**self.file_client_args)
+
+    def transform(self, results: Dict) -> Dict:
+        """Functions to load the segmentation map annotation.
+
+        Args:
+            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
+
+        Returns:
+            dict: The dict contains loaded seg map and meta information.
+        """
+        data_bytes = self.file_client.get(results['seg_map_path'])
+        gt_seg_map = datafrombytes(data_bytes, backend=self.decode_backend)
+
+        if self.to_float32:
+            gt_seg_map = gt_seg_map.astype(np.float32)
+
+        if self.decode_backend == 'nifti':
+            gt_seg_map = gt_seg_map.transpose(2, 1, 0)
+
+        if self.to_xyz:
+            gt_seg_map = gt_seg_map.transpose(2, 1, 0)
+
+        results['gt_seg_map'] = gt_seg_map
+        return results
+
+    def __repr__(self):
+        repr_str = (f'{self.__class__.__name__}('
+                    f"decode_backend='{self.decode_backend}', "
+                    f'to_xyz={self.to_xyz}, '
+                    f'to_float32={self.to_float32}, '
+                    f'file_client_args={self.file_client_args})')
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class LoadBiomedicalData(BaseTransform):
+    """Load a biomedical image and annotation from file.
+
+    The loading data format is as the following:
+
+    .. code-block:: python
+
+        {
+            'img': np.ndarray data[:-1, X, Y, Z]
+            'seg_map': np.ndarray data[-1, X, Y, Z]
+        }
+
+    Required Keys:
+
+    - img_path
+
+    Added Keys:
+
+    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
+        N is the number of modalities.
+    - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape
+        (Z, Y, X) by default.
+    - img_shape
+    - ori_shape
+
+    Args:
+        with_seg (bool): Whether to parse and load the semantic segmentation
+            annotation. Defaults to False.
+        decode_backend (str): The data decoding backend type. Options are
+            'numpy' and 'nifti', and there is a convention that when the
+            backend is 'nifti' the axis order of the loaded data is XYZ, and
+            when the backend is 'numpy' it is ZYX. The data will be
+            transposed if the backend is 'nifti'. Defaults to 'numpy'.
+        to_xyz (bool): Whether to transpose data from Z, Y, X to X, Y, Z.
+            Defaults to False.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmengine.fileio.FileClient` for details.
+            Defaults to ``dict(backend='disk')``.
+    """
+
+    def __init__(
+        self,
+        with_seg=False,
+        decode_backend: str = 'numpy',
+        to_xyz: bool = False,
+        file_client_args: dict = dict(backend='disk')
+    ) -> None:
+        self.with_seg = with_seg
+        self.decode_backend = decode_backend
+        self.to_xyz = to_xyz
+        self.file_client_args = file_client_args.copy()
+        self.file_client = mmengine.FileClient(**self.file_client_args)
+
+    def transform(self, results: Dict) -> Dict:
+        """Functions to load image and annotation.
+
+        Args:
+            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
+
+        Returns:
+            dict: The dict contains loaded image and meta information.
+        """
+        data_bytes = self.file_client.get(results['img_path'])
+        data = datafrombytes(data_bytes, backend=self.decode_backend)
+        # img is 4D data (N, X, Y, Z), N is the number of protocols
+        img = data[:-1, :]
+
+        if self.decode_backend == 'nifti':
+            img = img.transpose(0, 3, 2, 1)
+
+        if self.to_xyz:
+            img = img.transpose(0, 3, 2, 1)
+
+        results['img'] = img
+        results['img_shape'] = img.shape[1:]
+        results['ori_shape'] = img.shape[1:]
+
+        if self.with_seg:
+            gt_seg_map = data[-1, :]
+            if self.decode_backend == 'nifti':
+                gt_seg_map = gt_seg_map.transpose(2, 1, 0)
+
+            if self.to_xyz:
+                gt_seg_map = gt_seg_map.transpose(2, 1, 0)
+            results['gt_seg_map'] = gt_seg_map
+        return results
+
+    def __repr__(self) -> str:
+        repr_str = (f'{self.__class__.__name__}('
+                    f'with_seg={self.with_seg}, '
+                    f"decode_backend='{self.decode_backend}', "
+                    f'to_xyz={self.to_xyz}, '
+                    f'file_client_args={self.file_client_args})')
+        return repr_str
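For orientation (not part of the diff): a minimal sketch of how the new transforms could be called directly, consistent with the unit tests added later in this commit. The file paths are hypothetical placeholders; any gzipped NIfTI volume and matching annotation volume would do.

from mmseg.datasets.transforms import (LoadBiomedicalAnnotation,
                                       LoadBiomedicalImageFromFile)

# hypothetical files; decode_backend='nifti' expects gzipped NIfTI volumes
results = dict(img_path='data/case_0001.nii.gz',
               seg_map_path='data/case_0001_ann.nii.gz')
results = LoadBiomedicalImageFromFile(decode_backend='nifti')(results)
results = LoadBiomedicalAnnotation(decode_backend='nifti')(results)

print(results['img'].shape)         # (N, Z, Y, X), float32 by default
print(results['gt_seg_map'].shape)  # (Z, Y, X), float32 by default
print(results['img_shape'])         # equals results['img'].shape[1:]

In a pipeline config the same transforms would appear as ``dict(type='LoadBiomedicalImageFromFile', ...)`` entries, since both are registered in ``TRANSFORMS``.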
mmseg/utils/__init__.py

@@ -10,6 +10,7 @@ from .class_names import (ade_classes, ade_palette, cityscapes_classes,
                           voc_palette)
 # yapf: enable
 from .collect_env import collect_env
+from .io import datafrombytes
 from .misc import add_prefix, stack_batch
 from .set_env import register_all_modules
 from .typing import (ConfigType, ForwardResults, MultiConfig, OptConfigType,
@@ -25,5 +26,6 @@ __all__ = [
     'vaihingen_classes', 'isaid_classes', 'stare_classes',
     'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette',
     'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette',
-    'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette'
+    'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette',
+    'datafrombytes'
 ]
mmseg/utils/io.py (new file, 38 lines)

@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import gzip
+import io
+import pickle
+
+import numpy as np
+
+
+def datafrombytes(content: bytes, backend: str = 'numpy') -> np.ndarray:
+    """Data decoding from bytes.
+
+    Args:
+        content (bytes): The data bytes got from files or other streams.
+        backend (str): The data decoding backend type. Options are 'numpy',
+            'nifti' and 'pickle'. Defaults to 'numpy'.
+
+    Returns:
+        numpy.ndarray: Loaded data array.
+    """
+    if backend == 'pickle':
+        data = pickle.loads(content)
+    else:
+        with io.BytesIO(content) as f:
+            if backend == 'nifti':
+                f = gzip.open(f)
+                try:
+                    from nibabel import FileHolder, Nifti1Image
+                except ImportError:
+                    print('nifti files io depends on nibabel, please run '
+                          '`pip install nibabel` to install it')
+                fh = FileHolder(fileobj=f)
+                data = Nifti1Image.from_file_map({'header': fh, 'image': fh})
+                data = Nifti1Image.from_bytes(data.to_bytes()).get_fdata()
+            elif backend == 'numpy':
+                data = np.load(f)
+            else:
+                raise ValueError
+    return data
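For orientation (not part of the diff): a small sketch of calling the helper directly, mirroring the new unit test further below. The path points at the test data added in this commit; any local ``.nii.gz``, ``.npy`` or ``.pkl`` file produced the same way would work.

from mmengine import FileClient
from mmseg.utils import datafrombytes

# a gzipped NIfTI volume on local disk (here: the new test fixture)
content = FileClient('disk').get('tests/data/biomedical.nii.gz')
volume = datafrombytes(content, backend='nifti')  # np.ndarray, axis order X, Y, Z

# backend='numpy' expects raw np.save bytes, backend='pickle' expects pickled objects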
@@ -1 +1,2 @@
 cityscapesscripts
+nibabel
New binary test data (content not shown):

 BIN  tests/data/biomedical.nii.gz      (new executable file)
 BIN  tests/data/biomedical.npy         (new file)
 BIN  tests/data/biomedical.pkl         (new file)
 BIN  tests/data/biomedical_ann.nii.gz  (new executable file)
@@ -7,7 +7,11 @@ import mmcv
 import numpy as np
 from mmcv.transforms import LoadImageFromFile

-from mmseg.datasets.transforms import LoadAnnotations, LoadImageFromNDArray
+from mmseg.datasets.transforms import (LoadAnnotations,
+                                       LoadBiomedicalAnnotation,
+                                       LoadBiomedicalData,
+                                       LoadBiomedicalImageFromFile,
+                                       LoadImageFromNDArray)


 class TestLoading:
@@ -185,3 +189,53 @@ class TestLoading:
                                    "color_type='color', "
                                    "imdecode_backend='cv2', "
                                    "file_client_args={'backend': 'disk'})")
+
+    def test_load_biomedical_img(self):
+        results = dict(
+            img_path=osp.join(self.data_prefix, 'biomedical.nii.gz'))
+        transform = LoadBiomedicalImageFromFile()
+        results = transform(copy.deepcopy(results))
+        assert results['img_path'] == osp.join(self.data_prefix,
+                                               'biomedical.nii.gz')
+        assert len(results['img'].shape) == 4
+        assert results['img'].dtype == np.float32
+        assert results['ori_shape'] == results['img'].shape[1:]
+        assert repr(transform) == ('LoadBiomedicalImageFromFile('
+                                   "decode_backend='nifti', "
+                                   'to_xyz=False, '
+                                   'to_float32=True, '
+                                   "file_client_args={'backend': 'disk'})")
+
+    def test_load_biomedical_annotation(self):
+        results = dict(
+            seg_map_path=osp.join(self.data_prefix, 'biomedical_ann.nii.gz'))
+        transform = LoadBiomedicalAnnotation()
+        results = transform(copy.deepcopy(results))
+        assert len(results['gt_seg_map'].shape) == 3
+        assert results['gt_seg_map'].dtype == np.float32
+
+    def test_load_biomedical_data(self):
+        input_results = dict(
+            img_path=osp.join(self.data_prefix, 'biomedical.npy'))
+        transform = LoadBiomedicalData(with_seg=True)
+        results = transform(copy.deepcopy(input_results))
+        assert results['img_path'] == osp.join(self.data_prefix,
+                                               'biomedical.npy')
+        assert results['img'][0].shape == results['gt_seg_map'].shape
+        assert results['img'].dtype == np.float32
+        assert results['ori_shape'] == results['img'].shape[1:]
+        assert repr(transform) == ('LoadBiomedicalData('
+                                   'with_seg=True, '
+                                   "decode_backend='numpy', "
+                                   'to_xyz=False, '
+                                   "file_client_args={'backend': 'disk'})")
+
+        transform = LoadBiomedicalData(with_seg=False)
+        results = transform(copy.deepcopy(input_results))
+        assert len(results['img'].shape) == 4
+        assert results.get('gt_seg_map') is None
+        assert repr(transform) == ('LoadBiomedicalData('
+                                   'with_seg=False, '
+                                   "decode_backend='numpy', "
+                                   'to_xyz=False, '
+                                   "file_client_args={'backend': 'disk'})")
tests/test_utils/test_io.py (new file, 33 lines)

@@ -0,0 +1,33 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+import numpy as np
+import pytest
+from mmengine import FileClient
+
+from mmseg.utils import datafrombytes
+
+
+@pytest.mark.parametrize(
+    ['backend', 'suffix'],
+    [['nifti', '.nii.gz'], ['numpy', '.npy'], ['pickle', '.pkl']])
+def test_datafrombytes(backend, suffix):
+
+    file_client = FileClient('disk')
+    file_path = osp.join(osp.dirname(__file__), '../data/biomedical' + suffix)
+    bytes = file_client.get(file_path)
+    data = datafrombytes(bytes, backend)
+
+    if backend == 'pickle':
+        # test pickle loading
+        assert isinstance(data, dict)
+    else:
+        assert isinstance(data, np.ndarray)
+        if backend == 'nifti':
+            # test nifti file loading
+            assert len(data.shape) == 3
+        else:
+            # test npy file loading
+            # testing data biomedical.npy includes data and label
+            assert len(data.shape) == 4
+            assert data.shape[0] == 2