[Feature] Add Decathlon dataset (#2227)
* [Feature] Add Decathlon dataset
* fix test data
* add file
* remove order
* revise default value for prefix
* modify example
* revise based on comments
* add comments for ut

pull/2243/head
parent
293a057b61
commit
3d8fd35e26
mmseg/datasets/__init__.py

@@ -6,6 +6,7 @@ from .cityscapes import CityscapesDataset
 from .coco_stuff import COCOStuffDataset
 from .dark_zurich import DarkZurichDataset
 from .dataset_wrappers import MultiImageMixDataset
+from .decathlon import DecathlonDataset
 from .drive import DRIVEDataset
 from .hrf import HRFDataset
 from .isaid import iSAIDDataset
@@ -33,5 +34,6 @@ __all__ = [
     'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
     'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
     'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
-    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge'
+    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
+    'DecathlonDataset'
 ]

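With the import and `__all__` entry in place (together with the `@DATASETS.register_module()` decorator in the new module, shown further down), `DecathlonDataset` can be built from a config dict through the registry as well as imported directly. A hedged sketch; `data/brats` is a placeholder path, not something the PR defines:

import mmseg.datasets  # noqa: F401  # importing the package registers the dataset classes
from mmseg.registry import DATASETS

dataset_cfg = dict(
    type='DecathlonDataset',
    data_root='data/brats',   # placeholder, assumed to contain dataset.json
    ann_file='dataset.json',
    pipeline=[])              # real configs would list loading/transform steps
dataset = DATASETS.build(dataset_cfg)
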
mmseg/datasets/basesegdataset.py

@@ -85,7 +85,7 @@ class BaseSegDataset(BaseDataset):
                  seg_map_suffix='.png',
                  metainfo: Optional[dict] = None,
                  data_root: Optional[str] = None,
-                 data_prefix: dict = dict(img_path=None, seg_map_path=None),
+                 data_prefix: dict = dict(img_path='', seg_map_path=''),
                  filter_cfg: Optional[dict] = None,
                  indices: Optional[Union[int, Sequence[int]]] = None,
                  serialize_data: bool = True,
@@ -132,9 +132,6 @@ class BaseSegDataset(BaseDataset):
         # if it is not defined
         updated_palette = self._update_palette()
         self._metainfo.update(dict(palette=updated_palette))
-        if test_mode:
-            assert self._metainfo.get('classes') is not None, \
-                'dataset metainfo `classes` should be specified when testing'

         # Join paths.
         if self.data_root is not None:
@@ -146,6 +143,10 @@
         if not lazy_init:
             self.full_init()

+        if test_mode:
+            assert self._metainfo.get('classes') is not None, \
+                'dataset metainfo `classes` should be specified when testing'
+
     @classmethod
     def get_label_map(cls,
                       new_classes: Optional[Sequence] = None

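One note on the `data_prefix` default: the base dataset joins `data_root` with each prefix during initialisation, and a `None` prefix cannot be joined, while an empty string is a harmless no-op. A quick illustrative sketch with plain `os.path.join` (not the library's internal helper), using a placeholder root:

import os.path as osp

data_root = 'data/brats'  # placeholder root

# Old default: a None prefix breaks path joining.
try:
    osp.join(data_root, None)
except TypeError as err:
    print('None prefix fails:', err)

# New default: an empty prefix simply resolves to the root itself.
print(osp.join(data_root, ''))  # -> 'data/brats/'

The `test_mode` assertion is also moved after `full_init()`, so datasets such as `DecathlonDataset` that only learn their classes while loading the annotation file still pass the check.
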
mmseg/datasets/decathlon.py

@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from typing import List

from mmengine.fileio import load

from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset


@DATASETS.register_module()
class DecathlonDataset(BaseSegDataset):
    """Dataset for the Medical Segmentation Decathlon.

    The dataset.json format is shown as follows:

    .. code-block:: none

        {
            "name": "BRATS",
            "tensorImageSize": "4D",
            "modality":
            {
                "0": "FLAIR",
                "1": "T1w",
                "2": "t1gd",
                "3": "T2w"
            },
            "labels": {
                "0": "background",
                "1": "edema",
                "2": "non-enhancing tumor",
                "3": "enhancing tumour"
            },
            "numTraining": 484,
            "numTest": 266,
            "training":
            [
                {
                    "image": "./imagesTr/BRATS_306.nii.gz",
                    "label": "./labelsTr/BRATS_306.nii.gz"
                },
                ...
            ],
            "test":
            [
                "./imagesTs/BRATS_557.nii.gz",
                ...
            ]
        }
    """

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        # `self.ann_file` denotes the absolute annotation file path if
        # `self.data_root` is None, or a path relative to `self.data_root`
        # if `self.data_root=/path/to/data/`.
        annotations = load(self.ann_file)
        if not isinstance(annotations, dict):
            raise TypeError(f'The annotations loaded from annotation file '
                            f'should be a dict, but got {type(annotations)}!')
        raw_data_list = annotations[
            'training'] if not self.test_mode else annotations['test']
        data_list = []
        for raw_data_info in raw_data_list:
            # The leading './' in the file paths would break loading from
            # cloud storage, so it is stripped with `[2:]`.
            if isinstance(raw_data_info, dict):
                data_info = dict(
                    img_path=osp.join(self.data_root,
                                      raw_data_info['image'][2:]))
                data_info['seg_map_path'] = osp.join(
                    self.data_root, raw_data_info['label'][2:])
            else:
                data_info = dict(
                    img_path=osp.join(self.data_root, raw_data_info[2:]))
            data_info['label_map'] = self.label_map
            data_info['reduce_zero_label'] = self.reduce_zero_label
            data_info['seg_fields'] = []
            data_list.append(data_info)
        annotations.pop('training')
        annotations.pop('test')

        metainfo = copy.deepcopy(annotations)
        metainfo['classes'] = [*metainfo['labels'].values()]
        # Meta information loaded from the annotation file does not override
        # the meta information already set by `BaseDataset.METAINFO` or the
        # `metainfo` argument of the constructor.
        for k, v in metainfo.items():
            self._metainfo.setdefault(k, v)

        return data_list

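For direct use, a minimal sketch mirroring the unit test added below; `data/brats` is a placeholder root assumed to contain `dataset.json` together with the `imagesTr/`, `labelsTr/` and `imagesTs/` folders, and the empty pipeline is only there to keep the example self-contained:

from mmseg.datasets import DecathlonDataset

# Training split: one entry per item under "training" in dataset.json.
train_ds = DecathlonDataset(
    data_root='data/brats',  # placeholder
    ann_file='dataset.json',
    pipeline=[])             # real configs add loading/transform steps here
print(len(train_ds))
# Class names are filled from the "labels" field of dataset.json via metainfo.
print(train_ds.metainfo['classes'])

# Test split: test_mode=True makes load_data_list() read the "test" list instead.
test_ds = DecathlonDataset(
    data_root='data/brats',
    ann_file='dataset.json',
    pipeline=[],
    test_mode=True)
print(len(test_ds))
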
tests/data/dataset.json

@@ -0,0 +1,30 @@
{
    "name": "BRATS",
    "description": "Gliomas segmentation tumour and oedema in on brain images",
    "tensorImageSize": "4D",
    "modality": {
        "0": "FLAIR",
        "1": "T1w",
        "2": "t1gd",
        "3": "T2w"
    },
    "labels": {
        "0": "background",
        "1": "edema",
        "2": "non-enhancing tumor",
        "3": "enhancing tumour"
    },
    "numTraining": 484,
    "numTest": 266,
    "training": [
        {
            "image": "./imagesTr/BRATS_457.nii.gz",
            "label": "./labelsTr/BRATS_457.nii.gz"
        }
    ],
    "test": [
        "./imagesTs/BRATS_568.nii.gz",
        "./imagesTs/BRATS_515.nii.gz",
        "./imagesTs/BRATS_576.nii.gz"
    ]
}

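To make the fixture concrete, a small standalone check of the two parsing details `load_data_list()` applies to this file (class names taken from `labels`, the leading `./` stripped from paths), using plain Python and a placeholder root:

import os.path as osp

labels = {'0': 'background', '1': 'edema',
          '2': 'non-enhancing tumor', '3': 'enhancing tumour'}
classes = [*labels.values()]
assert classes == ['background', 'edema', 'non-enhancing tumor',
                   'enhancing tumour']

data_root = 'data/brats'  # placeholder
raw_image = './imagesTr/BRATS_457.nii.gz'
img_path = osp.join(data_root, raw_image[2:])  # strip the leading './'
assert img_path == 'data/brats/imagesTr/BRATS_457.nii.gz'
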
tests/test_datasets/test_dataset.py

@@ -7,8 +7,9 @@ from unittest.mock import MagicMock
 import pytest

 from mmseg.datasets import (ADE20KDataset, BaseSegDataset, CityscapesDataset,
-                            COCOStuffDataset, ISPRSDataset, LoveDADataset,
-                            PascalVOCDataset, PotsdamDataset, iSAIDDataset)
+                            COCOStuffDataset, DecathlonDataset, ISPRSDataset,
+                            LoveDADataset, PascalVOCDataset, PotsdamDataset,
+                            iSAIDDataset)
 from mmseg.registry import DATASETS
 from mmseg.utils import get_classes, get_palette

@@ -242,6 +243,22 @@ def test_isaid():
     assert len(test_dataset) == 1


+def test_decathlon():
+    data_root = osp.join(osp.dirname(__file__), '../data')
+    # test load training dataset
+    test_dataset = DecathlonDataset(
+        pipeline=[], data_root=data_root, ann_file='dataset.json')
+    assert len(test_dataset) == 1
+
+    # test load test dataset
+    test_dataset = DecathlonDataset(
+        pipeline=[],
+        data_root=data_root,
+        ann_file='dataset.json',
+        test_mode=True)
+    assert len(test_dataset) == 3
+
+
 @pytest.mark.parametrize('dataset, classes', [
     ('ADE20KDataset', ('wall', 'building')),
     ('CityscapesDataset', ('road', 'sidewalk')),