mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Add Decathlon dataset (#2227)
* [Feature] Add Decathlon dataset
* Fix test data
* Add file
* Remove order
* Revise default value for prefix
* Modify example
* Revise based on comments
* Add comments for unit test
This commit is contained in:
parent
293a057b61
commit
3d8fd35e26
@ -6,6 +6,7 @@ from .cityscapes import CityscapesDataset
|
|||||||
from .coco_stuff import COCOStuffDataset
|
from .coco_stuff import COCOStuffDataset
|
||||||
from .dark_zurich import DarkZurichDataset
|
from .dark_zurich import DarkZurichDataset
|
||||||
from .dataset_wrappers import MultiImageMixDataset
|
from .dataset_wrappers import MultiImageMixDataset
|
||||||
|
from .decathlon import DecathlonDataset
|
||||||
from .drive import DRIVEDataset
|
from .drive import DRIVEDataset
|
||||||
from .hrf import HRFDataset
|
from .hrf import HRFDataset
|
||||||
from .isaid import iSAIDDataset
|
from .isaid import iSAIDDataset
|
||||||
@ -33,5 +34,6 @@ __all__ = [
|
|||||||
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
|
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
|
||||||
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge'
|
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||||
|
'DecathlonDataset'
|
||||||
]
|
]
|
||||||
|
@ -85,7 +85,7 @@ class BaseSegDataset(BaseDataset):
|
|||||||
seg_map_suffix='.png',
|
seg_map_suffix='.png',
|
||||||
metainfo: Optional[dict] = None,
|
metainfo: Optional[dict] = None,
|
||||||
data_root: Optional[str] = None,
|
data_root: Optional[str] = None,
|
||||||
data_prefix: dict = dict(img_path=None, seg_map_path=None),
|
data_prefix: dict = dict(img_path='', seg_map_path=''),
|
||||||
filter_cfg: Optional[dict] = None,
|
filter_cfg: Optional[dict] = None,
|
||||||
indices: Optional[Union[int, Sequence[int]]] = None,
|
indices: Optional[Union[int, Sequence[int]]] = None,
|
||||||
serialize_data: bool = True,
|
serialize_data: bool = True,
|
||||||
@ -132,9 +132,6 @@ class BaseSegDataset(BaseDataset):
|
|||||||
# if it is not defined
|
# if it is not defined
|
||||||
updated_palette = self._update_palette()
|
updated_palette = self._update_palette()
|
||||||
self._metainfo.update(dict(palette=updated_palette))
|
self._metainfo.update(dict(palette=updated_palette))
|
||||||
if test_mode:
|
|
||||||
assert self._metainfo.get('classes') is not None, \
|
|
||||||
'dataset metainfo `classes` should be specified when testing'
|
|
||||||
|
|
||||||
# Join paths.
|
# Join paths.
|
||||||
if self.data_root is not None:
|
if self.data_root is not None:
|
||||||
@ -146,6 +143,10 @@ class BaseSegDataset(BaseDataset):
|
|||||||
if not lazy_init:
|
if not lazy_init:
|
||||||
self.full_init()
|
self.full_init()
|
||||||
|
|
||||||
|
if test_mode:
|
||||||
|
assert self._metainfo.get('classes') is not None, \
|
||||||
|
'dataset metainfo `classes` should be specified when testing'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_label_map(cls,
|
def get_label_map(cls,
|
||||||
new_classes: Optional[Sequence] = None
|
new_classes: Optional[Sequence] = None
|
||||||
|
96
mmseg/datasets/decathlon.py
Normal file
96
mmseg/datasets/decathlon.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import copy
|
||||||
|
import os.path as osp
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from mmengine.fileio import load
|
||||||
|
|
||||||
|
from mmseg.registry import DATASETS
|
||||||
|
from .basesegdataset import BaseSegDataset
|
||||||
|
|
||||||
|
|
||||||
|
@DATASETS.register_module()
class DecathlonDataset(BaseSegDataset):
    """Dataset for the Medical Segmentation Decathlon.

    The ``dataset.json`` annotation file has the following format:

    .. code-block:: none

        {
            "name": "BRATS",
            "tensorImageSize": "4D",
            "modality":
            {
                "0": "FLAIR",
                "1": "T1w",
                "2": "t1gd",
                "3": "T2w"
            },
            "labels": {
                "0": "background",
                "1": "edema",
                "2": "non-enhancing tumor",
                "3": "enhancing tumour"
            },
            "numTraining": 484,
            "numTest": 266,
            "training":
            [
                {
                    "image": "./imagesTr/BRATS_306.nii.gz"
                    "label": "./labelsTr/BRATS_306.nii.gz"
                    ...
                }
            ]
            "test":
            [
                "./imagesTs/BRATS_557.nii.gz"
                ...
            ]
        }
    """

    def load_data_list(self) -> List[dict]:
        """Load annotations from the Decathlon ``dataset.json`` file.

        The ``training`` split entries are dicts with ``image``/``label``
        paths; the ``test`` split entries are bare image path strings.
        All remaining top-level keys of the annotation file (``name``,
        ``modality``, ``labels``, ...) are merged into the dataset
        meta information, without overriding values already present in
        ``BaseDataset.METAINFO`` or the ``metainfo`` constructor argument.

        Returns:
            list[dict]: All data info of the dataset.

        Raises:
            TypeError: If the loaded annotation file is not a dict.
        """
        # `self.ann_file` denotes the absolute annotation file path if
        # `self.root=None` or relative path if `self.root=/path/to/data/`.
        annotations = load(self.ann_file)
        if not isinstance(annotations, dict):
            raise TypeError(f'The annotations loaded from annotation file '
                            f'should be a dict, but got {type(annotations)}!')
        raw_data_list = annotations[
            'training'] if not self.test_mode else annotations['test']
        data_list = []
        for raw_data_info in raw_data_list:
            # `[2:]` removes the leading './' in the file path, which would
            # break loading from cloud storage.
            if isinstance(raw_data_info, dict):
                data_info = dict(
                    img_path=osp.join(self.data_root,
                                      raw_data_info['image'][2:]))
                data_info['seg_map_path'] = osp.join(
                    self.data_root, raw_data_info['label'][2:])
            else:
                # Bug fix: strip './' from the raw path BEFORE joining.
                # The original applied `[2:]` to the joined result, which
                # chopped the first two characters off `data_root`.
                data_info = dict(
                    img_path=osp.join(self.data_root, raw_data_info[2:]))
            data_info['label_map'] = self.label_map
            data_info['reduce_zero_label'] = self.reduce_zero_label
            data_info['seg_fields'] = []
            data_list.append(data_info)
        # Drop the (possibly large) sample lists before copying the rest
        # into metainfo. Use a default so a file with only one split
        # does not raise KeyError.
        annotations.pop('training', None)
        annotations.pop('test', None)

        metainfo = copy.deepcopy(annotations)
        metainfo['classes'] = [*metainfo['labels'].values()]
        # Meta information loaded from the annotation file must not
        # override the existing meta information from
        # `BaseDataset.METAINFO` or the `metainfo` constructor argument.
        for k, v in metainfo.items():
            self._metainfo.setdefault(k, v)

        return data_list
|
30
tests/data/dataset.json
Executable file
30
tests/data/dataset.json
Executable file
@ -0,0 +1,30 @@
|
|||||||
|
{
|
||||||
|
"name": "BRATS",
|
||||||
|
"description": "Gliomas segmentation tumour and oedema in on brain images",
|
||||||
|
"tensorImageSize": "4D",
|
||||||
|
"modality": {
|
||||||
|
"0": "FLAIR",
|
||||||
|
"1": "T1w",
|
||||||
|
"2": "t1gd",
|
||||||
|
"3": "T2w"
|
||||||
|
},
|
||||||
|
"labels": {
|
||||||
|
"0": "background",
|
||||||
|
"1": "edema",
|
||||||
|
"2": "non-enhancing tumor",
|
||||||
|
"3": "enhancing tumour"
|
||||||
|
},
|
||||||
|
"numTraining": 484,
|
||||||
|
"numTest": 266,
|
||||||
|
"training": [
|
||||||
|
{
|
||||||
|
"image": "./imagesTr/BRATS_457.nii.gz",
|
||||||
|
"label": "./labelsTr/BRATS_457.nii.gz"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"test": [
|
||||||
|
"./imagesTs/BRATS_568.nii.gz",
|
||||||
|
"./imagesTs/BRATS_515.nii.gz",
|
||||||
|
"./imagesTs/BRATS_576.nii.gz"
|
||||||
|
]
|
||||||
|
}
|
@ -7,8 +7,9 @@ from unittest.mock import MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mmseg.datasets import (ADE20KDataset, BaseSegDataset, CityscapesDataset,
|
from mmseg.datasets import (ADE20KDataset, BaseSegDataset, CityscapesDataset,
|
||||||
COCOStuffDataset, ISPRSDataset, LoveDADataset,
|
COCOStuffDataset, DecathlonDataset, ISPRSDataset,
|
||||||
PascalVOCDataset, PotsdamDataset, iSAIDDataset)
|
LoveDADataset, PascalVOCDataset, PotsdamDataset,
|
||||||
|
iSAIDDataset)
|
||||||
from mmseg.registry import DATASETS
|
from mmseg.registry import DATASETS
|
||||||
from mmseg.utils import get_classes, get_palette
|
from mmseg.utils import get_classes, get_palette
|
||||||
|
|
||||||
@ -242,6 +243,22 @@ def test_isaid():
|
|||||||
assert len(test_dataset) == 1
|
assert len(test_dataset) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_decathlon():
    """Verify DecathlonDataset lengths for the training and test splits."""
    data_root = osp.join(osp.dirname(__file__), '../data')

    # Training split (test_mode defaults to False): fixture holds 1 sample.
    train_dataset = DecathlonDataset(
        pipeline=[], data_root=data_root, ann_file='dataset.json')
    assert len(train_dataset) == 1

    # Test split: fixture lists 3 image paths.
    eval_dataset = DecathlonDataset(
        pipeline=[],
        data_root=data_root,
        ann_file='dataset.json',
        test_mode=True)
    assert len(eval_dataset) == 3
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('dataset, classes', [
|
@pytest.mark.parametrize('dataset, classes', [
|
||||||
('ADE20KDataset', ('wall', 'building')),
|
('ADE20KDataset', ('wall', 'building')),
|
||||||
('CityscapesDataset', ('road', 'sidewalk')),
|
('CityscapesDataset', ('road', 'sidewalk')),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user