Merge branch 'zhengmiao/refactory-dataset' into 'refactor_dev'

[Refactory] Dataset refactory

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!15 (pull/1801/head)

commit d64f941fb3
@@ -1,10 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-
-import mmcv
-import numpy as np
-from PIL import Image
-
 from mmseg.registry import DATASETS
 from .custom import CustomDataset
@@ -18,33 +12,36 @@ class ADE20KDataset(CustomDataset):
     The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
     '.png'.
     """
-    CLASSES = (
-        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
-        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
-        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
-        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
-        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
-        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
-        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
-        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
-        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
-        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
-        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
-        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
-        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
-        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
-        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
-        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
-        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
-        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
-        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
-        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
-        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
-        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
-        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
-        'clock', 'flag')
-
-    PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+    METAINFO = dict(
+        classes=('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road',
+                 'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk',
+                 'person', 'earth', 'door', 'table', 'mountain', 'plant',
+                 'curtain', 'chair', 'car', 'water', 'painting', 'sofa',
+                 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair',
+                 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp',
+                 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+                 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+                 'skyscraper', 'fireplace', 'refrigerator', 'grandstand',
+                 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow',
+                 'screen door', 'stairway', 'river', 'bridge', 'bookcase',
+                 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill',
+                 'bench', 'countertop', 'stove', 'palm', 'kitchen island',
+                 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine',
+                 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+                 'chandelier', 'awning', 'streetlight', 'booth',
+                 'television receiver', 'airplane', 'dirt track', 'apparel',
+                 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle',
+                 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain',
+                 'conveyer belt', 'canopy', 'washer', 'plaything',
+                 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall',
+                 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food',
+                 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal',
+                 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket',
+                 'sculpture', 'hood', 'sconce', 'vase', 'traffic light',
+                 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate',
+                 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+                 'clock', 'flag'),
+        palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
                [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
                [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
                [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
@@ -81,87 +78,11 @@ class ADE20KDataset(CustomDataset):
                [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
                [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
                [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
-               [102, 255, 0], [92, 0, 255]]
+               [102, 255, 0], [92, 0, 255]])

-    def __init__(self, **kwargs):
-        super(ADE20KDataset, self).__init__(
+    def __init__(self, **kwargs) -> None:
+        super().__init__(
             img_suffix='.jpg',
             seg_map_suffix='.png',
             reduce_zero_label=True,
             **kwargs)
-
-    def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
-        """Write the segmentation results to images.
-
-        Args:
-            results (list[ndarray]): Testing results of the
-                dataset.
-            imgfile_prefix (str): The filename prefix of the png files.
-                If the prefix is "somepath/xxx",
-                the png files will be named "somepath/xxx.png".
-            to_label_id (bool): whether convert output to label_id for
-                submission.
-            indices (list[int], optional): Indices of input results, if not
-                set, all the indices of the dataset will be used.
-                Default: None.
-
-        Returns:
-            list[str: str]: result txt files which contains corresponding
-            semantic segmentation images.
-        """
-        if indices is None:
-            indices = list(range(len(self)))
-
-        mmcv.mkdir_or_exist(imgfile_prefix)
-        result_files = []
-        for result, idx in zip(results, indices):
-
-            filename = self.img_infos[idx]['filename']
-            basename = osp.splitext(osp.basename(filename))[0]
-
-            png_filename = osp.join(imgfile_prefix, f'{basename}.png')
-
-            # The index range of official requirement is from 0 to 150.
-            # But the index range of output is from 0 to 149.
-            # That is because we set reduce_zero_label=True.
-            result = result + 1
-
-            output = Image.fromarray(result.astype(np.uint8))
-            output.save(png_filename)
-            result_files.append(png_filename)
-
-        return result_files
-
-    def format_results(self,
-                       results,
-                       imgfile_prefix,
-                       to_label_id=True,
-                       indices=None):
-        """Format the results into dir (standard format for ade20k evaluation).
-
-        Args:
-            results (list): Testing results of the dataset.
-            imgfile_prefix (str | None): The prefix of images files. It
-                includes the file path and the prefix of filename, e.g.,
-                "a/b/prefix".
-            to_label_id (bool): whether convert output to label_id for
-                submission. Default: False
-            indices (list[int], optional): Indices of input results, if not
-                set, all the indices of the dataset will be used.
-                Default: None.
-
-        Returns:
-            tuple: (result_files, tmp_dir), result_files is a list containing
-                the image paths, tmp_dir is the temporal directory created
-                for saving json/png files when img_prefix is not specified.
-        """
-
-        if indices is None:
-            indices = list(range(len(self)))
-
-        assert isinstance(results, list), 'results must be a list.'
-        assert isinstance(indices, list), 'indices must be a list.'
-
-        result_files = self.results2img(results, imgfile_prefix, to_label_id,
-                                        indices)
-        return result_files
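Note: the change above is the template for every dataset file in this MR — the separate CLASSES/PALETTE class attributes become a single METAINFO dict consumed by mmengine's BaseDataset machinery. A minimal sketch of reading the new metadata, checked only against the code shown in this diff (the import path assumes the usual mmseg package layout):

    from mmseg.datasets import ADE20KDataset

    # Class-level metadata is available without instantiating the dataset.
    classes = ADE20KDataset.METAINFO['classes']   # ('wall', 'building', ...)
    palette = ADE20KDataset.METAINFO['palette']   # [[120, 120, 120], ...]
    assert len(classes) == 150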
@@ -13,15 +13,14 @@ class ChaseDB1Dataset(CustomDataset):
     The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
     '_1stHO.png'.
     """
+    METAINFO = dict(
+        classes=('background', 'vessel'),
+        palette=[[120, 120, 120], [6, 230, 230]])

-    CLASSES = ('background', 'vessel')
-
-    PALETTE = [[120, 120, 120], [6, 230, 230]]
-
-    def __init__(self, **kwargs):
-        super(ChaseDB1Dataset, self).__init__(
+    def __init__(self, **kwargs) -> None:
+        super().__init__(
             img_suffix='.png',
             seg_map_suffix='_1stHO.png',
             reduce_zero_label=False,
             **kwargs)
-        assert self.file_client.exists(self.img_dir)
+        assert self.file_client.exists(self.data_prefix['img_path'])
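The final assert now checks the new data_prefix dict instead of the removed img_dir attribute. A hedged sketch of the corresponding construction, with placeholder paths:

    # Paths below are illustrative only; data_prefix replaces the old
    # img_dir/ann_dir pair.
    dataset = ChaseDB1Dataset(
        data_root='data/CHASE_DB1',
        data_prefix=dict(
            img_path='images/training',
            seg_map_path='annotations/training'),
        pipeline=[])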
@@ -1,11 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-
-import mmcv
-import numpy as np
-from mmcv.utils import print_log
-from PIL import Image
-
 from mmseg.registry import DATASETS
 from .custom import CustomDataset
@@ -17,198 +10,21 @@ class CityscapesDataset(CustomDataset):
     The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is
     fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
     """
-    CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
-               'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
-               'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
-               'bicycle')
-
-    PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
-               [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
-               [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
-               [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
-               [0, 80, 100], [0, 0, 230], [119, 11, 32]]
+    METAINFO = dict(
+        classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+                 'traffic light', 'traffic sign', 'vegetation', 'terrain',
+                 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
+                 'motorcycle', 'bicycle'),
+        palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+                 [190, 153, 153], [153, 153, 153], [250, 170,
+                                                    30], [220, 220, 0],
+                 [107, 142, 35], [152, 251, 152], [70, 130, 180],
+                 [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
+                 [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]])

     def __init__(self,
                  img_suffix='_leftImg8bit.png',
                  seg_map_suffix='_gtFine_labelTrainIds.png',
-                 **kwargs):
-        super(CityscapesDataset, self).__init__(
+                 **kwargs) -> None:
+        super().__init__(
             img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
-
-    @staticmethod
-    def _convert_to_label_id(result):
-        """Convert trainId to id for cityscapes."""
-        if isinstance(result, str):
-            result = np.load(result)
-        import cityscapesscripts.helpers.labels as CSLabels
-        result_copy = result.copy()
-        for trainId, label in CSLabels.trainId2label.items():
-            result_copy[result == trainId] = label.id
-
-        return result_copy
-
-    def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
-        """Write the segmentation results to images.
-
-        Args:
-            results (list[ndarray]): Testing results of the
-                dataset.
-            imgfile_prefix (str): The filename prefix of the png files.
-                If the prefix is "somepath/xxx",
-                the png files will be named "somepath/xxx.png".
-            to_label_id (bool): whether convert output to label_id for
-                submission.
-            indices (list[int], optional): Indices of input results,
-                if not set, all the indices of the dataset will be used.
-                Default: None.
-
-        Returns:
-            list[str: str]: result txt files which contains corresponding
-            semantic segmentation images.
-        """
-        if indices is None:
-            indices = list(range(len(self)))
-
-        mmcv.mkdir_or_exist(imgfile_prefix)
-        result_files = []
-        for result, idx in zip(results, indices):
-            if to_label_id:
-                result = self._convert_to_label_id(result)
-            filename = self.img_infos[idx]['filename']
-            basename = osp.splitext(osp.basename(filename))[0]
-
-            png_filename = osp.join(imgfile_prefix, f'{basename}.png')
-
-            output = Image.fromarray(result.astype(np.uint8)).convert('P')
-            import cityscapesscripts.helpers.labels as CSLabels
-            palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
-            for label_id, label in CSLabels.id2label.items():
-                palette[label_id] = label.color
-
-            output.putpalette(palette)
-            output.save(png_filename)
-            result_files.append(png_filename)
-
-        return result_files
-
-    def format_results(self,
-                       results,
-                       imgfile_prefix,
-                       to_label_id=True,
-                       indices=None):
-        """Format the results into dir (standard format for Cityscapes
-        evaluation).
-
-        Args:
-            results (list): Testing results of the dataset.
-            imgfile_prefix (str): The prefix of images files. It
-                includes the file path and the prefix of filename, e.g.,
-                "a/b/prefix".
-            to_label_id (bool): whether convert output to label_id for
-                submission. Default: False
-            indices (list[int], optional): Indices of input results,
-                if not set, all the indices of the dataset will be used.
-                Default: None.
-
-        Returns:
-            tuple: (result_files, tmp_dir), result_files is a list containing
-                the image paths, tmp_dir is the temporal directory created
-                for saving json/png files when img_prefix is not specified.
-        """
-        if indices is None:
-            indices = list(range(len(self)))
-
-        assert isinstance(results, list), 'results must be a list.'
-        assert isinstance(indices, list), 'indices must be a list.'
-
-        result_files = self.results2img(results, imgfile_prefix, to_label_id,
-                                        indices)
-
-        return result_files
-
-    def evaluate(self,
-                 results,
-                 metric='mIoU',
-                 logger=None,
-                 imgfile_prefix=None):
-        """Evaluation in Cityscapes/default protocol.
-
-        Args:
-            results (list): Testing results of the dataset.
-            metric (str | list[str]): Metrics to be evaluated.
-            logger (logging.Logger | None | str): Logger used for printing
-                related information during evaluation. Default: None.
-            imgfile_prefix (str | None): The prefix of output image file,
-                for cityscapes evaluation only. It includes the file path and
-                the prefix of filename, e.g., "a/b/prefix".
-                If results are evaluated with cityscapes protocol, it would be
-                the prefix of output png files. The output files would be
-                png images under folder "a/b/prefix/xxx.png", where "xxx" is
-                the image name of cityscapes. If not specified, a temp file
-                will be created for evaluation.
-                Default: None.
-
-        Returns:
-            dict[str, float]: Cityscapes/default metrics.
-        """
-
-        eval_results = dict()
-        metrics = metric.copy() if isinstance(metric, list) else [metric]
-        if 'cityscapes' in metrics:
-            eval_results.update(
-                self._evaluate_cityscapes(results, logger, imgfile_prefix))
-            metrics.remove('cityscapes')
-        if len(metrics) > 0:
-            eval_results.update(
-                super(CityscapesDataset,
-                      self).evaluate(results, metrics, logger))
-
-        return eval_results
-
-    def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
-        """Evaluation in Cityscapes protocol.
-
-        Args:
-            results (list): Testing results of the dataset.
-            logger (logging.Logger | str | None): Logger used for printing
-                related information during evaluation. Default: None.
-            imgfile_prefix (str | None): The prefix of output image file
-
-        Returns:
-            dict[str: float]: Cityscapes evaluation results.
-        """
-        try:
-            import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval  # noqa
-        except ImportError:
-            raise ImportError('Please run "pip install cityscapesscripts" to '
-                              'install cityscapesscripts first.')
-        msg = 'Evaluating in Cityscapes style'
-        if logger is None:
-            msg = '\n' + msg
-        print_log(msg, logger=logger)
-
-        result_dir = imgfile_prefix
-
-        eval_results = dict()
-        print_log(f'Evaluating results under {result_dir} ...', logger=logger)
-
-        CSEval.args.evalInstLevelScore = True
-        CSEval.args.predictionPath = osp.abspath(result_dir)
-        CSEval.args.evalPixelAccuracy = True
-        CSEval.args.JSONOutput = False
-
-        seg_map_list = []
-        pred_list = []
-
-        # when evaluating with official cityscapesscripts,
-        # **_gtFine_labelIds.png is used
-        for seg_map in mmcv.scandir(
-                self.ann_dir, 'gtFine_labelIds.png', recursive=True):
-            seg_map_list.append(osp.join(self.ann_dir, seg_map))
-            pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
-
-        eval_results.update(
-            CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
-
-        return eval_results
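All submission formatting and cityscapesscripts-based evaluation is removed from the dataset class here. For reference, the core of the deleted trainId-to-labelId conversion as a standalone sketch (same logic as the removed _convert_to_label_id; requires the cityscapesscripts package):

    import numpy as np
    import cityscapesscripts.helpers.labels as CSLabels

    def convert_to_label_id(result: np.ndarray) -> np.ndarray:
        # Map training ids (0-18) back to raw Cityscapes label ids,
        # as required by the official evaluation server.
        result_copy = result.copy()
        for train_id, label in CSLabels.trainId2label.items():
            result_copy[result == train_id] = label.id
        return result_copy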
@@ -14,38 +14,40 @@ class COCOStuffDataset(CustomDataset):
     10k and 164k versions, respectively. The ``img_suffix`` is fixed to '.jpg',
     and ``seg_map_suffix`` is fixed to '.png'.
     """
-    CLASSES = (
-        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
-        'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
-        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
-        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
-        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
-        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
-        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
-        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
-        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
-        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
-        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
-        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
-        'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
-        'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
-        'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
-        'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
-        'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
-        'floor-other', 'floor-stone', 'floor-tile', 'floor-wood',
-        'flower', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass',
-        'gravel', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat',
-        'metal', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net',
-        'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 'platform',
-        'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',
-        'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',
-        'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',
-        'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',
-        'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',
-        'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
-        'window-blind', 'window-other', 'wood')
-
-    PALETTE = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
+    METAINFO = dict(
+        classes=(
+            'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+            'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
+            'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+            'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+            'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+            'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
+            'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+            'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+            'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+            'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+            'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
+            'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
+            'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
+            'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
+            'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
+            'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
+            'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
+            'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
+            'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',
+            'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',
+            'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',
+            'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net',
+            'paper', 'pavement', 'pillow', 'plant-other', 'plastic',
+            'platform', 'playingfield', 'railing', 'railroad', 'river', 'road',
+            'rock', 'roof', 'rug', 'salad', 'sand', 'sea', 'shelf',
+            'sky-other', 'skyscraper', 'snow', 'solid-other', 'stairs',
+            'stone', 'straw', 'structural-other', 'table', 'tent',
+            'textile-other', 'towel', 'tree', 'vegetable', 'wall-brick',
+            'wall-concrete', 'wall-other', 'wall-panel', 'wall-stone',
+            'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
+            'window-blind', 'window-other', 'wood'),
+        palette=[[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
                [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
                [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
                [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
@@ -87,8 +89,8 @@ class COCOStuffDataset(CustomDataset):
                [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],
                [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],
                [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],
-               [64, 192, 96], [64, 160, 64], [64, 64, 0]]
+               [64, 192, 96], [64, 160, 64], [64, 64, 0]])

-    def __init__(self, **kwargs):
-        super(COCOStuffDataset, self).__init__(
+    def __init__(self, **kwargs) -> None:
+        super().__init__(
             img_suffix='.jpg', seg_map_suffix='_labelTrainIds.png', **kwargs)
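COCO-Stuff keeps its '_labelTrainIds.png' annotation suffix, so image/label pairing stays purely suffix-based under the new load_data_list in custom.py below (the file name here is illustrative):

    img = '000000001234.jpg'                        # hypothetical file name
    seg_map = img.replace('.jpg', '_labelTrainIds.png')
    # -> '000000001234_labelTrainIds.png'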
@@ -1,22 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
 import os.path as osp
-import warnings
-from collections import OrderedDict
+from typing import Callable, Dict, List, Optional, Sequence, Union

 import mmcv
 import numpy as np
-from mmcv.utils import print_log
-from prettytable import PrettyTable
-from torch.utils.data import Dataset
-
-from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
+from mmengine.dataset import BaseDataset, Compose
+
 from mmseg.registry import DATASETS
-from mmseg.utils import get_root_logger
-from .pipelines import Compose, LoadAnnotations


 @DATASETS.register_module()
-class CustomDataset(Dataset):
+class CustomDataset(BaseDataset):
     """Custom dataset for semantic segmentation. An example of file structure
     is as followed.

@@ -46,330 +41,163 @@ class CustomDataset(Dataset):


     Args:
-        pipeline (list[dict]): Processing pipeline
-        img_dir (str): Path to image directory
+        ann_file (str): Annotation file path. Defaults to ''.
+        metainfo (dict, optional): Meta information for dataset, such as
+            specify classes to load. Defaults to None.
+        data_root (str, optional): The root directory for ``data_prefix`` and
+            ``ann_file``. Defaults to None.
+        data_prefix (dict, optional): Prefix for training data. Defaults to
+            dict(img_path=None, seg_path=None).
         img_suffix (str): Suffix of images. Default: '.jpg'
-        ann_dir (str, optional): Path to annotation directory. Default: None
         seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
-        split (str, optional): Split txt file. If split is specified, only
-            file with suffix in the splits will be loaded. Otherwise, all
-            images in img_dir/ann_dir will be loaded. Default: None
-        data_root (str, optional): Data root for img_dir/ann_dir. Default:
-            None.
-        test_mode (bool): If test_mode=True, gt wouldn't be loaded.
+        filter_cfg (dict, optional): Config for filter data. Defaults to None.
+        indices (int or Sequence[int], optional): Support using first few
+            data in annotation file to facilitate training/testing on a smaller
+            dataset. Defaults to None which means using all ``data_infos``.
+        serialize_data (bool, optional): Whether to hold memory using
+            serialized objects, when enabled, data loader workers can use
+            shared RAM from master process instead of making a copy. Defaults
+            to True.
+        pipeline (list, optional): Processing pipeline. Defaults to [].
+        test_mode (bool, optional): ``test_mode=True`` means in test phase.
+            Defaults to False.
+        lazy_init (bool, optional): Whether to load annotation during
+            instantiation. In some cases, such as visualization, only the meta
+            information of the dataset is needed, which is not necessary to
+            load annotation file. ``Basedataset`` can skip load annotations to
+            save time by set ``lazy_init=False``. Defaults to False.
+        max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
+            None img. The maximum extra number of cycles to get a valid
+            image. Defaults to 1000.
         ignore_index (int): The label index to be ignored. Default: 255
         reduce_zero_label (bool): Whether to mark label zero as ignored.
             Default: False
-        classes (str | Sequence[str], optional): Specify classes to load.
-            If is None, ``cls.CLASSES`` will be used. Default: None.
-        palette (Sequence[Sequence[int]]] | np.ndarray | None):
-            The palette of segmentation map. If None is given, and
-            self.PALETTE is None, random palette will be generated.
-            Default: None
-        gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to
-            load gt for evaluation, load from disk by default. Default: None.
         file_client_args (dict): Arguments to instantiate a FileClient.
             See :class:`mmcv.fileio.FileClient` for details.
             Defaults to ``dict(backend='disk')``.
     """
+    METAINFO: dict = dict()

-    CLASSES = None
-
-    PALETTE = None
-
-    def __init__(self,
-                 pipeline,
-                 img_dir,
-                 img_suffix='.jpg',
-                 ann_dir=None,
-                 seg_map_suffix='.png',
-                 split=None,
-                 data_root=None,
-                 test_mode=False,
-                 ignore_index=255,
-                 reduce_zero_label=False,
-                 classes=None,
-                 palette=None,
-                 gt_seg_map_loader_cfg=None,
-                 file_client_args=dict(backend='disk')):
-        self.pipeline = Compose(pipeline)
-        self.img_dir = img_dir
+    def __init__(
+        self,
+        ann_file: str = '',
+        img_suffix='.jpg',
+        seg_map_suffix='.png',
+        metainfo: Optional[dict] = None,
+        data_root: Optional[str] = None,
+        data_prefix: dict = dict(img_path=None, seg_map_path=None),
+        filter_cfg: Optional[dict] = None,
+        indices: Optional[Union[int, Sequence[int]]] = None,
+        serialize_data: bool = True,
+        pipeline: List[Union[dict, Callable]] = [],
+        test_mode: bool = False,
+        lazy_init: bool = False,
+        max_refetch: int = 1000,
+        ignore_index: int = 255,
+        reduce_zero_label: bool = True,
+        file_client_args: dict = dict(backend='disk')
+    ) -> None:
+
         self.img_suffix = img_suffix
-        self.ann_dir = ann_dir
         self.seg_map_suffix = seg_map_suffix
-        self.split = split
-        self.data_root = data_root
-        self.test_mode = test_mode
         self.ignore_index = ignore_index
         self.reduce_zero_label = reduce_zero_label
-        self.label_map = None
-        self.CLASSES, self.PALETTE = self.get_classes_and_palette(
-            classes, palette)
-        self.gt_seg_map_loader = LoadAnnotations(
-        ) if gt_seg_map_loader_cfg is None else LoadAnnotations(
-            **gt_seg_map_loader_cfg)
-
         self.file_client_args = file_client_args
         self.file_client = mmcv.FileClient.infer_client(self.file_client_args)

+        self.data_root = data_root
+        self.data_prefix = copy.copy(data_prefix)
+        self.ann_file = ann_file
+        self.filter_cfg = copy.deepcopy(filter_cfg)
+        self._indices = indices
+        self.serialize_data = serialize_data
+        self.test_mode = test_mode
+        self.max_refetch = max_refetch
+        self.data_list: List[dict] = []
+        self.data_bytes: np.ndarray
+
+        # Set meta information.
+        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))
+
+        # Get label map for custom classes
+        new_classes = self._metainfo.get('classes', None)
+        self.label_map = self.get_label_map(new_classes)
+
+        # Update palette based on label map or generate palette
+        # if it is not defined
+        updated_palette = self._update_palette()
+        self._metainfo.update(dict(palette=updated_palette))
         if test_mode:
-            assert self.CLASSES is not None, \
-                '`cls.CLASSES` or `classes` should be specified when testing'
+            assert self._metainfo.get('classes') is not None, \
+                'dataset metainfo `classes` should be specified when testing'

-        # join paths if data_root is specified
+        # Join paths.
         if self.data_root is not None:
-            if not osp.isabs(self.img_dir):
-                self.img_dir = osp.join(self.data_root, self.img_dir)
-            if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
-                self.ann_dir = osp.join(self.data_root, self.ann_dir)
-            if not (self.split is None or osp.isabs(self.split)):
-                self.split = osp.join(self.data_root, self.split)
+            self._join_prefix()

-        # load annotations
-        self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
-                                               self.ann_dir,
-                                               self.seg_map_suffix, self.split)
+        # Build pipeline.
+        self.pipeline = Compose(pipeline)
+        # Full initialize the dataset.
+        if not lazy_init:
+            self.full_init()

-    def __len__(self):
-        """Total number of samples of data."""
-        return len(self.img_infos)
-
-    def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
-                         split):
-        """Load annotation from directory.
-
-        Args:
-            img_dir (str): Path to image directory
-            img_suffix (str): Suffix of images.
-            ann_dir (str|None): Path to annotation directory.
-            seg_map_suffix (str|None): Suffix of segmentation maps.
-            split (str|None): Split txt file. If split is specified, only file
-                with suffix in the splits will be loaded. Otherwise, all images
-                in img_dir/ann_dir will be loaded. Default: None
-
-        Returns:
-            list[dict]: All image info of dataset.
-        """
-
-        img_infos = []
-        if split is not None:
-            lines = mmcv.list_from_file(
-                split, file_client_args=self.file_client_args)
-            for line in lines:
-                img_name = line.strip()
-                img_info = dict(filename=img_name + img_suffix)
-                if ann_dir is not None:
-                    seg_map = img_name + seg_map_suffix
-                    img_info['ann'] = dict(seg_map=seg_map)
-                img_infos.append(img_info)
-        else:
-            for img in self.file_client.list_dir_or_file(
-                    dir_path=img_dir,
-                    list_dir=False,
-                    suffix=img_suffix,
-                    recursive=True):
-                img_info = dict(filename=img)
-                if ann_dir is not None:
-                    seg_map = img.replace(img_suffix, seg_map_suffix)
-                    img_info['ann'] = dict(seg_map=seg_map)
-                img_infos.append(img_info)
-            img_infos = sorted(img_infos, key=lambda x: x['filename'])
-
-        print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
-        return img_infos
-
-    def get_ann_info(self, idx):
-        """Get annotation by index.
-
-        Args:
-            idx (int): Index of data.
-
-        Returns:
-            dict: Annotation info of specified index.
-        """
-
-        return self.img_infos[idx]['ann']
-
-    def pre_pipeline(self, results):
-        """Prepare results dict for pipeline."""
-        results['seg_fields'] = []
-        results['img_prefix'] = self.img_dir
-        results['seg_prefix'] = self.ann_dir
-        if self.custom_classes:
-            results['label_map'] = self.label_map
-
-    def __getitem__(self, idx):
-        """Get training/test data after pipeline.
-
-        Args:
-            idx (int): Index of data.
-
-        Returns:
-            dict: Training/test data (with annotation if `test_mode` is set
-                False).
-        """
-
-        if self.test_mode:
-            return self.prepare_test_img(idx)
-        else:
-            return self.prepare_train_img(idx)
-
-    def prepare_train_img(self, idx):
-        """Get training data and annotations after pipeline.
-
-        Args:
-            idx (int): Index of data.
-
-        Returns:
-            dict: Training data and annotation after pipeline with new keys
-                introduced by pipeline.
-        """
-
-        img_info = self.img_infos[idx]
-        ann_info = self.get_ann_info(idx)
-        results = dict(img_info=img_info, ann_info=ann_info)
-        self.pre_pipeline(results)
-        return self.pipeline(results)
-
-    def prepare_test_img(self, idx):
-        """Get testing data after pipeline.
-
-        Args:
-            idx (int): Index of data.
-
-        Returns:
-            dict: Testing data after pipeline with new keys introduced by
-                pipeline.
-        """
-
-        img_info = self.img_infos[idx]
-        results = dict(img_info=img_info)
-        self.pre_pipeline(results)
-        return self.pipeline(results)
-
-    def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
-        """Place holder to format result to dataset specific output."""
-        raise NotImplementedError
-
-    def get_gt_seg_map_by_idx(self, index):
-        """Get one ground truth segmentation map for evaluation."""
-        ann_info = self.get_ann_info(index)
-        results = dict(ann_info=ann_info)
-        self.pre_pipeline(results)
-        self.gt_seg_map_loader(results)
-        return results['gt_semantic_seg']
-
-    def get_gt_seg_maps(self, efficient_test=None):
-        """Get ground truth segmentation maps for evaluation."""
-        if efficient_test is not None:
-            warnings.warn(
-                'DeprecationWarning: ``efficient_test`` has been deprecated '
-                'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
-                'friendly by default. ')
-
-        for idx in range(len(self)):
-            ann_info = self.get_ann_info(idx)
-            results = dict(ann_info=ann_info)
-            self.pre_pipeline(results)
-            self.gt_seg_map_loader(results)
-            yield results['gt_semantic_seg']
-
-    def pre_eval(self, preds, indices):
-        """Collect eval result from each iteration.
-
-        Args:
-            preds (list[torch.Tensor] | torch.Tensor): the segmentation logit
-                after argmax, shape (N, H, W).
-            indices (list[int] | int): the prediction related ground truth
-                indices.
-
-        Returns:
-            list[torch.Tensor]: (area_intersect, area_union, area_prediction,
-                area_ground_truth).
-        """
-        # In order to compat with batch inference
-        if not isinstance(indices, list):
-            indices = [indices]
-        if not isinstance(preds, list):
-            preds = [preds]
-
-        pre_eval_results = []
-
-        for pred, index in zip(preds, indices):
-            seg_map = self.get_gt_seg_map_by_idx(index)
-            pre_eval_results.append(
-                intersect_and_union(
-                    pred,
-                    seg_map,
-                    len(self.CLASSES),
-                    self.ignore_index,
-                    # as the labels has been converted when dataset initialized
-                    # in `get_palette_for_custom_classes ` this `label_map`
-                    # should be `dict()`, see
-                    # https://github.com/open-mmlab/mmsegmentation/issues/1415
-                    # for more ditails
-                    label_map=dict(),
-                    reduce_zero_label=self.reduce_zero_label))
-
-        return pre_eval_results
-
-    def get_classes_and_palette(self, classes=None, palette=None):
-        """Get class names of current dataset.
-
-        Args:
-            classes (Sequence[str] | str | None): If classes is None, use
-                default CLASSES defined by builtin dataset. If classes is a
-                string, take it as a file name. The file contains the name of
-                classes where each line contains one class name. If classes is
-                a tuple or list, override the CLASSES defined by the dataset.
-            palette (Sequence[Sequence[int]]] | np.ndarray | None):
-                The palette of segmentation map. If None is given, random
-                palette will be generated. Default: None
-        """
-        if classes is None:
-            self.custom_classes = False
-            return self.CLASSES, self.PALETTE
-
-        self.custom_classes = True
-        if isinstance(classes, str):
-            # take it as a file path
-            class_names = mmcv.list_from_file(classes)
-        elif isinstance(classes, (tuple, list)):
-            class_names = classes
-        else:
-            raise ValueError(f'Unsupported type {type(classes)} of classes.')
-
-        if self.CLASSES:
-            if not set(class_names).issubset(self.CLASSES):
-                raise ValueError('classes is not a subset of CLASSES.')
-
-            # dictionary, its keys are the old label ids and its values
-            # are the new label ids.
-            # used for changing pixel labels in load_annotations.
-            self.label_map = {}
-            for i, c in enumerate(self.CLASSES):
-                if c not in class_names:
-                    self.label_map[i] = -1
-                else:
-                    self.label_map[i] = class_names.index(c)
-
-        palette = self.get_palette_for_custom_classes(class_names, palette)
-
-        return class_names, palette
-
-    def get_palette_for_custom_classes(self, class_names, palette=None):
-
-        if self.label_map is not None:
-            # return subset of palette
-            palette = []
-            for old_id, new_id in sorted(
-                    self.label_map.items(), key=lambda x: x[1]):
-                if new_id != -1:
-                    palette.append(self.PALETTE[old_id])
-            palette = type(self.PALETTE)(palette)
-
-        elif palette is None:
-            if self.PALETTE is None:
+    @classmethod
+    def get_label_map(cls,
+                      new_classes: Optional[Sequence] = None
+                      ) -> Union[Dict, None]:
+        """Require label mapping.
+
+        The ``label_map`` is a dictionary, its keys are the old label ids and
+        its values are the new label ids, and is used for changing pixel
+        labels in load_annotations. If and only if old classes in cls.METAINFO
+        is not equal to new classes in self._metainfo and nether of them is not
+        None, `label_map` is not None.
+
+        Args:
+            new_classes (list, tuple, optional): The new classes name from
+                metainfo. Default to None.
+
+        Returns:
+            dict, optional: The mapping from old classes in cls.METAINFO to
+                new classes in self._metainfo
+        """
+        old_classes = cls.METAINFO.get('classes', None)
+        if (new_classes is not None and old_classes is not None
+                and list(new_classes) != list(old_classes)):
+
+            label_map = {}
+            if not set(new_classes).issubset(cls.METAINFO['classes']):
+                raise ValueError(
+                    f'new classes {new_classes} is not a '
+                    f'subset of classes {old_classes} in METAINFO.')
+            for i, c in enumerate(old_classes):
+                if c not in new_classes:
+                    label_map[i] = -1
+                else:
+                    label_map[i] = new_classes.index(c)
+            return label_map
+        else:
+            return None
+
+    def _update_palette(self) -> list:
+        """Update palette after loading metainfo.
+
+        If length of palette is equal to classes, just return the palette.
+        If palette is not defined, it will randomly generate a palette.
+        If classes is updated by customer, it will return the subset of
+        palette.
+
+        Returns:
+            Sequence: Palette for current dataset.
+        """
+        palette = self._metainfo.get('palette', [])
+        classes = self._metainfo.get('classes', [])
+        # palette does match classes
+        if len(palette) == len(classes):
+            return palette
+
+        if len(palette) == 0:
             # Get random state before set seed, and restore
             # random state later.
             # It will prevent loss of randomness, as the palette
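The new get_label_map classmethod above replaces the old get_classes_and_palette machinery: a class subset passed through metainfo yields an old-id -> new-id dict, with dropped classes mapped to -1. A toy illustration derived directly from the code in this hunk:

    class ToyDataset(CustomDataset):
        # Hypothetical three-class dataset, for illustration only.
        METAINFO = dict(classes=('a', 'b', 'c'))

    # Keep 'c' and 'a': old id 0 -> 1, old id 1 -> dropped, old id 2 -> 0.
    assert ToyDataset.get_label_map(('c', 'a')) == {0: 1, 1: -1, 2: 0}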
@@ -378,110 +206,55 @@ class CustomDataset(Dataset):
                 state = np.random.get_state()
                 np.random.seed(42)
                 # random palette
-                palette = np.random.randint(0, 255, size=(len(class_names), 3))
+                new_palette = np.random.randint(
+                    0, 255, size=(len(classes), 3)).tolist()
                 np.random.set_state(state)
+        elif len(palette) >= len(classes) and self.label_map is not None:
+            new_palette = []
+            # return subset of palette
+            for old_id, new_id in sorted(
+                    self.label_map.items(), key=lambda x: x[1]):
+                if new_id != -1:
+                    new_palette.append(palette[old_id])
+            new_palette = type(palette)(new_palette)
         else:
-            palette = self.PALETTE
+            raise ValueError('palette does not match classes '
+                             f'as metainfo is {self._metainfo}.')
+        return new_palette

-        return palette
-
-    def evaluate(self,
-                 results,
-                 metric='mIoU',
-                 logger=None,
-                 gt_seg_maps=None,
-                 **kwargs):
-        """Evaluate the dataset.
-
-        Args:
-            results (list[tuple[torch.Tensor]] | list[str]): per image pre_eval
-                results or predict segmentation map for computing evaluation
-                metric.
-            metric (str | list[str]): Metrics to be evaluated. 'mIoU',
-                'mDice' and 'mFscore' are supported.
-            logger (logging.Logger | None | str): Logger used for printing
-                related information during evaluation. Default: None.
-            gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
-                used in ConcatDataset
+    def load_data_list(self) -> List[dict]:
+        """Load annotation from directory or annotation file.

         Returns:
-            dict[str, float]: Default metrics.
+            list[dict]: All data info of dataset.
         """
-        if isinstance(metric, str):
-            metric = [metric]
-        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
-        if not set(metric).issubset(set(allowed_metrics)):
-            raise KeyError('metric {} is not supported'.format(metric))
-
-        eval_results = {}
-        # test a list of files
-        if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
-                results, str):
-            if gt_seg_maps is None:
-                gt_seg_maps = self.get_gt_seg_maps()
-            num_classes = len(self.CLASSES)
-            ret_metrics = eval_metrics(
-                results,
-                gt_seg_maps,
-                num_classes,
-                self.ignore_index,
-                metric,
-                label_map=dict(),
-                reduce_zero_label=self.reduce_zero_label)
-        # test a list of pre_eval_results
+        data_list = []
+        ann_dir = self.data_prefix.get('seg_map_path', None)
+        if osp.isfile(self.ann_file):
+            lines = mmcv.list_from_file(
+                self.ann_file, file_client_args=self.file_client_args)
+            for line in lines:
+                img_name = line.strip()
+                data_info = dict(img_path=img_name + self.img_suffix)
+                if ann_dir is not None:
+                    seg_map = img_name + self.seg_map_suffix
+                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
+                data_info['label_map'] = self.label_map
+                data_info['seg_field'] = []
+                data_list.append(data_info)
         else:
-            ret_metrics = pre_eval_to_metrics(results, metric)
-
-        # Because dataset.CLASSES is required for per-eval.
-        if self.CLASSES is None:
-            class_names = tuple(range(num_classes))
-        else:
-            class_names = self.CLASSES
-
-        # summary table
-        ret_metrics_summary = OrderedDict({
-            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
-            for ret_metric, ret_metric_value in ret_metrics.items()
-        })
-
-        # each class table
-        ret_metrics.pop('aAcc', None)
-        ret_metrics_class = OrderedDict({
-            ret_metric: np.round(ret_metric_value * 100, 2)
-            for ret_metric, ret_metric_value in ret_metrics.items()
-        })
-        ret_metrics_class.update({'Class': class_names})
-        ret_metrics_class.move_to_end('Class', last=False)
-
-        # for logger
-        class_table_data = PrettyTable()
-        for key, val in ret_metrics_class.items():
-            class_table_data.add_column(key, val)
-
-        summary_table_data = PrettyTable()
-        for key, val in ret_metrics_summary.items():
-            if key == 'aAcc':
-                summary_table_data.add_column(key, [val])
-            else:
-                summary_table_data.add_column('m' + key, [val])
-
-        print_log('per class results:', logger)
-        print_log('\n' + class_table_data.get_string(), logger=logger)
-        print_log('Summary:', logger)
-        print_log('\n' + summary_table_data.get_string(), logger=logger)
-
-        # each metric dict
-        for key, value in ret_metrics_summary.items():
-            if key == 'aAcc':
-                eval_results[key] = value / 100.0
-            else:
-                eval_results['m' + key] = value / 100.0
-
-        ret_metrics_class.pop('Class', None)
-        for key, value in ret_metrics_class.items():
-            eval_results.update({
-                key + '.' + str(name): value[idx] / 100.0
-                for idx, name in enumerate(class_names)
-            })
-
-        return eval_results
+            img_dir = self.data_prefix['img_path']
+            for img in self.file_client.list_dir_or_file(
+                    dir_path=img_dir,
+                    list_dir=False,
+                    suffix=self.img_suffix,
+                    recursive=True):
+                data_info = dict(img_path=osp.join(img_dir, img))
+                if ann_dir is not None:
+                    seg_map = img.replace(self.img_suffix, self.seg_map_suffix)
+                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
+                data_info['label_map'] = self.label_map
+                data_info['seg_field'] = []
+                data_list.append(data_info)
+            data_list = sorted(data_list, key=lambda x: x['img_path'])
+        return data_list
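load_data_list replaces load_annotations and emits mmengine-style records instead of the old img_info/ann_info pairs; each sample now looks roughly like this (paths are placeholders, not taken from the diff):

    data_info = dict(
        img_path='data/my_dataset/img_dir/0001.jpg',       # hypothetical
        seg_map_path='data/my_dataset/ann_dir/0001.png',   # hypothetical
        label_map=None,   # or the old-id -> new-id dict from get_label_map
        seg_field=[])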
@@ -7,7 +7,7 @@ from .cityscapes import CityscapesDataset
 class DarkZurichDataset(CityscapesDataset):
     """DarkZurichDataset dataset."""

-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs) -> None:
         super().__init__(
             img_suffix='_rgb_anon.png',
             seg_map_suffix='_gt_labelTrainIds.png',
@@ -13,15 +13,14 @@ class DRIVEDataset(CustomDataset):
     ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
     '_manual1.png'.
     """
+    METAINFO = dict(
+        classes=('background', 'vessel'),
+        palette=[[120, 120, 120], [6, 230, 230]])

-    CLASSES = ('background', 'vessel')
-
-    PALETTE = [[120, 120, 120], [6, 230, 230]]
-
-    def __init__(self, **kwargs):
-        super(DRIVEDataset, self).__init__(
+    def __init__(self, **kwargs) -> None:
+        super().__init__(
             img_suffix='.png',
             seg_map_suffix='_manual1.png',
             reduce_zero_label=False,
             **kwargs)
-        assert self.file_client.exists(self.img_dir)
+        assert self.file_client.exists(self.data_prefix['img_path'])
@@ -13,15 +13,14 @@ class HRFDataset(CustomDataset):
     ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
     '.png'.
     """
+    METAINFO = dict(
+        classes=('background', 'vessel'),
+        palette=[[120, 120, 120], [6, 230, 230]])

-    CLASSES = ('background', 'vessel')
-
-    PALETTE = [[120, 120, 120], [6, 230, 230]]
-
-    def __init__(self, **kwargs):
-        super(HRFDataset, self).__init__(
+    def __init__(self, **kwargs) -> None:
+        super().__init__(
             img_suffix='.png',
             seg_map_suffix='.png',
             reduce_zero_label=False,
             **kwargs)
-        assert self.file_client.exists(self.img_dir)
+        assert self.file_client.exists(self.data_prefix['img_path'])
@@ -1,10 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-
-import mmcv
-from mmcv.utils import print_log
-
 from mmseg.registry import DATASETS
-from ..utils import get_root_logger
 from .custom import CustomDataset

@@ -17,66 +12,21 @@ class iSAIDDataset(CustomDataset):
     '_manual1.png'.
     """
-    CLASSES = ('background', 'ship', 'store_tank', 'baseball_diamond',
-               'tennis_court', 'basketball_court', 'Ground_Track_Field',
-               'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter',
-               'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane',
-               'Harbor')
-
-    PALETTE = [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
-               [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
-               [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127],
-               [0, 127, 191], [0, 127, 255], [0, 100, 155]]
+    METAINFO = dict(
+        classes=('background', 'ship', 'store_tank', 'baseball_diamond',
+                 'tennis_court', 'basketball_court', 'Ground_Track_Field',
+                 'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter',
+                 'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane',
+                 'Harbor'),
+        palette=[[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
+                 [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
+                 [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127],
+                 [0, 127, 191], [0, 127, 255], [0, 100, 155]])
 
-    def __init__(self, **kwargs):
-        super(iSAIDDataset, self).__init__(
+    def __init__(self, **kwargs) -> None:
+        super().__init__(
             img_suffix='.png',
-            seg_map_suffix='.png',
+            seg_map_suffix='_instance_color_RGB.png',
             ignore_index=255,
             **kwargs)
-        assert self.file_client.exists(self.img_dir)
-
-    def load_annotations(self,
-                         img_dir,
-                         img_suffix,
-                         ann_dir,
-                         seg_map_suffix=None,
-                         split=None):
-        """Load annotation from directory.
-
-        Args:
-            img_dir (str): Path to image directory
-            img_suffix (str): Suffix of images.
-            ann_dir (str|None): Path to annotation directory.
-            seg_map_suffix (str|None): Suffix of segmentation maps.
-            split (str|None): Split txt file. If split is specified, only file
-                with suffix in the splits will be loaded. Otherwise, all images
-                in img_dir/ann_dir will be loaded. Default: None
-
-        Returns:
-            list[dict]: All image info of dataset.
-        """
-
-        img_infos = []
-        if split is not None:
-            with open(split) as f:
-                for line in f:
-                    name = line.strip()
-                    img_info = dict(filename=name + img_suffix)
-                    if ann_dir is not None:
-                        ann_name = name + '_instance_color_RGB'
-                        seg_map = ann_name + seg_map_suffix
-                        img_info['ann'] = dict(seg_map=seg_map)
-                    img_infos.append(img_info)
-        else:
-            for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
-                img_info = dict(filename=img)
-                if ann_dir is not None:
-                    seg_img = img
-                    seg_map = seg_img.replace(
-                        img_suffix, '_instance_color_RGB' + seg_map_suffix)
-                    img_info['ann'] = dict(seg_map=seg_map)
-                img_infos.append(img_info)
-
-        print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
-        return img_infos
+        assert self.file_client.exists(self.data_prefix['img_path'])

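With the annotation suffix folded into `seg_map_suffix='_instance_color_RGB.png'`, the hand-rolled `load_annotations` override above becomes dead weight: the base class can derive each map path by plain suffix substitution, which is exactly what the deleted loop did by hand. A sketch of the equivalence (illustrative file name, not the base-class implementation):

    img_suffix = '.png'
    seg_map_suffix = '_instance_color_RGB.png'

    img = 'P0003_0_896_1024_1920.png'
    # What the deleted branch computed explicitly ...
    old_style = img.replace(img_suffix, '_instance_color_RGB' + '.png')
    # ... is what suffix-based matching now yields for free:
    new_style = img[:-len(img_suffix)] + seg_map_suffix
    assert old_style == new_style
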
@@ -11,14 +11,14 @@ class ISPRSDataset(CustomDataset):
     ``reduce_zero_label`` should be set to True. The ``img_suffix`` and
     ``seg_map_suffix`` are both fixed to '.png'.
     """
-    CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree',
-               'car', 'clutter')
-
-    PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
-               [255, 255, 0], [255, 0, 0]]
-
-    def __init__(self, **kwargs):
-        super(ISPRSDataset, self).__init__(
+    METAINFO = dict(
+        classes=('impervious_surface', 'building', 'low_vegetation', 'tree',
+                 'car', 'clutter'),
+        palette=[[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
+                 [255, 255, 0], [255, 0, 0]])
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(
             img_suffix='.png',
             seg_map_suffix='.png',
             reduce_zero_label=True,

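The ISPRS hunk keeps `reduce_zero_label=True`, which is what makes the 'impervious_surface'-first class list line up with labels that start at 1 on disk. A sketch of the transform this flag applies when a map is loaded (mirrors the standard mmseg behavior; 255 is the usual ignore index):

    import numpy as np

    seg = np.array([0, 1, 2, 6], dtype=np.uint8)  # 0 = unlabeled on disk
    seg[seg == 0] = 255    # unlabeled -> ignore
    seg = seg - 1          # shift remaining classes down by one
    seg[seg == 254] = 255  # keep ignore pixels at 255 after the shift
    print(seg)             # [255   0   1   5]
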
@@ -1,10 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-
-import mmcv
-import numpy as np
-from PIL import Image
-
 from mmseg.registry import DATASETS
 from .custom import CustomDataset
 
@@ -17,76 +11,15 @@ class LoveDADataset(CustomDataset):
     ``reduce_zero_label`` should be set to True. The ``img_suffix`` and
     ``seg_map_suffix`` are both fixed to '.png'.
     """
-    CLASSES = ('background', 'building', 'road', 'water', 'barren', 'forest',
-               'agricultural')
-
-    PALETTE = [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
-               [159, 129, 183], [0, 255, 0], [255, 195, 128]]
+    METAINFO = dict(
+        classes=('background', 'building', 'road', 'water', 'barren', 'forest',
+                 'agricultural'),
+        palette=[[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
+                 [159, 129, 183], [0, 255, 0], [255, 195, 128]])
 
-    def __init__(self, **kwargs):
-        super(LoveDADataset, self).__init__(
+    def __init__(self, **kwargs) -> None:
+        super().__init__(
             img_suffix='.png',
             seg_map_suffix='.png',
             reduce_zero_label=True,
             **kwargs)
-
-    def results2img(self, results, imgfile_prefix, indices=None):
-        """Write the segmentation results to images.
-
-        Args:
-            results (list[ndarray]): Testing results of the
-                dataset.
-            imgfile_prefix (str): The filename prefix of the png files.
-                If the prefix is "somepath/xxx",
-                the png files will be named "somepath/xxx.png".
-            indices (list[int], optional): Indices of input results, if not
-                set, all the indices of the dataset will be used.
-                Default: None.
-
-        Returns:
-            list[str: str]: result txt files which contains corresponding
-            semantic segmentation images.
-        """
-
-        mmcv.mkdir_or_exist(imgfile_prefix)
-        result_files = []
-        for result, idx in zip(results, indices):
-
-            filename = self.img_infos[idx]['filename']
-            basename = osp.splitext(osp.basename(filename))[0]
-
-            png_filename = osp.join(imgfile_prefix, f'{basename}.png')
-
-            # The index range of official requirement is from 0 to 6.
-            output = Image.fromarray(result.astype(np.uint8))
-            output.save(png_filename)
-            result_files.append(png_filename)
-
-        return result_files
-
-    def format_results(self, results, imgfile_prefix, indices=None):
-        """Format the results into dir (standard format for LoveDA evaluation).
-
-        Args:
-            results (list): Testing results of the dataset.
-            imgfile_prefix (str): The prefix of images files. It
-                includes the file path and the prefix of filename, e.g.,
-                "a/b/prefix".
-            indices (list[int], optional): Indices of input results,
-                if not set, all the indices of the dataset will be used.
-                Default: None.
-
-        Returns:
-            tuple: (result_files, tmp_dir), result_files is a list containing
-                the image paths, tmp_dir is the temporal directory created
-                for saving json/png files when img_prefix is not specified.
-        """
-        if indices is None:
-            indices = list(range(len(self)))
-
-        assert isinstance(results, list), 'results must be a list.'
-        assert isinstance(indices, list), 'indices must be a list.'
-
-        result_files = self.results2img(results, imgfile_prefix, indices)
-        return result_files

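Dropping `results2img`/`format_results` is what lets the previous hunk delete the osp/mmcv/numpy/PIL imports; submission formatting now has to live outside the dataset class. For anyone who still needs LoveDA-style PNG submissions, a standalone sketch of what the deleted method did (hypothetical helper, same logic as the removed code):

    import os
    import os.path as osp

    import numpy as np
    from PIL import Image

    def results2img(results, basenames, imgfile_prefix):
        """Save each predicted map as <imgfile_prefix>/<basename>.png.
        The official LoveDA server expects class indices in [0, 6]."""
        os.makedirs(imgfile_prefix, exist_ok=True)
        files = []
        for result, basename in zip(results, basenames):
            png = osp.join(imgfile_prefix, f'{basename}.png')
            Image.fromarray(result.astype(np.uint8)).save(png)
            files.append(png)
        return files
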
@@ -7,7 +7,7 @@ from .cityscapes import CityscapesDataset
 class NightDrivingDataset(CityscapesDataset):
     """NightDrivingDataset dataset."""
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs) -> None:
         super().__init__(
             img_suffix='_leftImg8bit.png',
             seg_map_suffix='_gtCoarse_labelTrainIds.png',

@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
 
 from mmseg.registry import DATASETS
 from .custom import CustomDataset
 
@@ -14,21 +15,21 @@ class PascalContextDataset(CustomDataset):
     fixed to '.png'.
 
     Args:
-        split (str): Split txt file for PascalContext.
+        ann_file (str): Annotation file path.
     """
 
-    CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
-               'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
-               'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
-               'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
-               'floor', 'flower', 'food', 'grass', 'ground', 'horse',
-               'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
-               'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
-               'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
-               'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
-               'window', 'wood')
-
-    PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
-               [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
-               [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
-               [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+    METAINFO = dict(
+        classes=('background', 'aeroplane', 'bag', 'bed', 'bedclothes',
+                 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle',
+                 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling',
+                 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog',
+                 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground',
+                 'horse', 'keyboard', 'light', 'motorbike', 'mountain',
+                 'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road',
+                 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow',
+                 'sofa', 'table', 'track', 'train', 'tree', 'truck',
+                 'tvmonitor', 'wall', 'water', 'window', 'wood'),
+        palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+                 [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+                 [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+                 [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],

@@ -42,16 +43,17 @@ class PascalContextDataset(CustomDataset):
              [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
              [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
              [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
-             [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]
+             [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])
 
-    def __init__(self, split, **kwargs):
-        super(PascalContextDataset, self).__init__(
+    def __init__(self, ann_file: str, **kwargs) -> None:
+        super().__init__(
             img_suffix='.jpg',
             seg_map_suffix='.png',
-            split=split,
+            ann_file=ann_file,
             reduce_zero_label=False,
             **kwargs)
-        assert self.file_client.exists(self.img_dir) and self.split is not None
+        assert self.file_client.exists(
+            self.data_prefix['img_path']) and osp.isfile(self.ann_file)
 
 
 @DATASETS.register_module()

@@ -64,40 +66,41 @@ class PascalContextDataset59(CustomDataset):
     fixed to '.png'.
 
     Args:
-        split (str): Split txt file for PascalContext.
+        ann_file (str): Annotation file path.
     """
+    METAINFO = dict(
+        classes=('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
+                 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
+                 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
+                 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
+                 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
+                 'keyboard', 'light', 'motorbike', 'mountain', 'mouse',
+                 'person', 'plate', 'platform', 'pottedplant', 'road', 'rock',
+                 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa',
+                 'table', 'track', 'train', 'tree', 'truck', 'tvmonitor',
+                 'wall', 'water', 'window', 'wood'),
+        palette=[[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
+                 [120, 120, 80], [140, 140, 140], [204, 5, 255],
+                 [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+                 [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+                 [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+                 [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+                 [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+                 [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+                 [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+                 [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+                 [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+                 [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+                 [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+                 [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+                 [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])
 
-    CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
-               'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet',
-               'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
-               'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
-               'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
-               'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
-               'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
-               'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
-               'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')
-
-    PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
-               [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
-               [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
-               [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
-               [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
-               [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
-               [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
-               [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
-               [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
-               [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
-               [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
-               [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
-               [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
-               [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
-               [0, 235, 255], [0, 173, 255], [31, 0, 255]]
-
-    def __init__(self, split, **kwargs):
-        super(PascalContextDataset59, self).__init__(
+    def __init__(self, ann_file: str, **kwargs):
+        super().__init__(
             img_suffix='.jpg',
             seg_map_suffix='.png',
-            split=split,
+            ann_file=ann_file,
             reduce_zero_label=True,
             **kwargs)
-        assert self.file_client.exists(self.img_dir) and self.split is not None
+        assert self.file_client.exists(
+            self.data_prefix['img_path']) and osp.isfile(self.ann_file)

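Both PascalContext classes swap the bespoke `split` argument for the standard `ann_file`, and the sanity assert now checks the annotation file with `osp.isfile` alongside the image prefix. A construction sketch under the new signature (paths are placeholders modeled on the usual VOC2010 layout):

    dataset = PascalContextDataset(
        data_root='data/VOCdevkit/VOC2010',
        data_prefix=dict(
            img_path='JPEGImages',
            seg_map_path='SegmentationClassContext'),
        ann_file='ImageSets/SegmentationContext/train.txt',
        pipeline=[])
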
@@ -11,14 +11,14 @@ class PotsdamDataset(CustomDataset):
     ``reduce_zero_label`` should be set to True. The ``img_suffix`` and
     ``seg_map_suffix`` are both fixed to '.png'.
     """
-    CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree',
-               'car', 'clutter')
-
-    PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
-               [255, 255, 0], [255, 0, 0]]
-
-    def __init__(self, **kwargs):
-        super(PotsdamDataset, self).__init__(
+    METAINFO = dict(
+        classes=('impervious_surface', 'building', 'low_vegetation', 'tree',
+                 'car', 'clutter'),
+        palette=[[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
+                 [255, 255, 0], [255, 0, 0]])
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(
             img_suffix='.png',
             seg_map_suffix='.png',
             reduce_zero_label=True,

@@ -1,6 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-
 from mmseg.registry import DATASETS
 from .custom import CustomDataset
 
@@ -14,15 +12,14 @@ class STAREDataset(CustomDataset):
     ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
     '.ah.png'.
     """
-    CLASSES = ('background', 'vessel')
-
-    PALETTE = [[120, 120, 120], [6, 230, 230]]
-
-    def __init__(self, **kwargs):
-        super(STAREDataset, self).__init__(
+    METAINFO = dict(
+        classes=('background', 'vessel'),
+        palette=[[120, 120, 120], [6, 230, 230]])
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(
             img_suffix='.png',
             seg_map_suffix='.ah.png',
             reduce_zero_label=False,
             **kwargs)
-        assert osp.exists(self.img_dir)
+        assert self.file_client.exists(self.data_prefix['img_path'])

@@ -12,19 +12,23 @@ class PascalVOCDataset(CustomDataset):
     Args:
         split (str): Split txt file for Pascal VOC.
     """
+    METAINFO = dict(
+        classes=('background', 'aeroplane', 'bicycle', 'bird', 'boat',
+                 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
+                 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
+                 'sofa', 'train', 'tvmonitor'),
+        palette=[[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
+                 [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
+                 [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
+                 [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
+                 [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
+                 [0, 64, 128]])
 
-    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
-               'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
-               'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
-               'train', 'tvmonitor')
-
-    PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
-               [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
-               [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
-               [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
-               [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
-
-    def __init__(self, split, **kwargs):
-        super(PascalVOCDataset, self).__init__(
-            img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs)
-        assert osp.exists(self.img_dir) and self.split is not None
+    def __init__(self, ann_file, **kwargs) -> None:
+        super().__init__(
+            img_suffix='.jpg',
+            seg_map_suffix='.png',
+            ann_file=ann_file,
+            **kwargs)
+        assert self.file_client.exists(
+            self.data_prefix['img_path']) and osp.isfile(self.ann_file)

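The new test file below exercises the refactored API end to end: METAINFO against get_classes/get_palette, classes loaded from a txt file, data_prefix/ann_file construction, and metainfo overrides. It assumes a small fixture tree; the layout is inferred from the paths and counts in the tests, so treat it as illustrative:

    # tests/data/pseudo_dataset/
    #     imgs/              5 files ending in 'img.jpg'
    #     gts/               matching maps ending in 'gt.png'
    #     splits/train.txt   lists 4 of the 5 stems (hence len == 4)
    # plus analogous pseudo_cityscapes/loveda/potsdam/vaihingen/isaid trees.
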
@@ -0,0 +1,353 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import tempfile
+from unittest.mock import MagicMock
+
+import pytest
+
+from mmseg.core.evaluation import get_classes, get_palette
+from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
+                            COCOStuffDataset, CustomDataset, ISPRSDataset,
+                            LoveDADataset, PascalVOCDataset, PotsdamDataset,
+                            iSAIDDataset)
+
+
+def test_classes():
+    assert list(
+        CityscapesDataset.METAINFO['classes']) == get_classes('cityscapes')
+    assert list(PascalVOCDataset.METAINFO['classes']) == get_classes(
+        'voc') == get_classes('pascal_voc')
+    assert list(ADE20KDataset.METAINFO['classes']) == get_classes(
+        'ade') == get_classes('ade20k')
+    assert list(
+        COCOStuffDataset.METAINFO['classes']) == get_classes('cocostuff')
+    assert list(LoveDADataset.METAINFO['classes']) == get_classes('loveda')
+    assert list(PotsdamDataset.METAINFO['classes']) == get_classes('potsdam')
+    assert list(ISPRSDataset.METAINFO['classes']) == get_classes('vaihingen')
+    assert list(iSAIDDataset.METAINFO['classes']) == get_classes('isaid')
+
+    with pytest.raises(ValueError):
+        get_classes('unsupported')
+
+
+def test_classes_file_path():
+    tmp_file = tempfile.NamedTemporaryFile()
+    classes_path = f'{tmp_file.name}.txt'
+    train_pipeline = []
+    kwargs = dict(
+        pipeline=train_pipeline,
+        data_prefix=dict(img_path='./', seg_map_path='./'),
+        metainfo=dict(classes=classes_path))
+
+    # classes.txt with full categories
+    categories = get_classes('cityscapes')
+    with open(classes_path, 'w') as f:
+        f.write('\n'.join(categories))
+    dataset = CityscapesDataset(**kwargs)
+    assert list(dataset.metainfo['classes']) == categories
+    assert dataset.label_map is None
+
+    # classes.txt with sub categories
+    categories = ['road', 'sidewalk', 'building']
+    with open(classes_path, 'w') as f:
+        f.write('\n'.join(categories))
+    dataset = CityscapesDataset(**kwargs)
+    assert list(dataset.metainfo['classes']) == categories
+    assert dataset.label_map is not None
+
+    # classes.txt with unknown categories
+    categories = ['road', 'sidewalk', 'unknown']
+    with open(classes_path, 'w') as f:
+        f.write('\n'.join(categories))
+
+    with pytest.raises(ValueError):
+        CityscapesDataset(**kwargs)
+
+    tmp_file.close()
+    os.remove(classes_path)
+    assert not osp.exists(classes_path)
+
+
+def test_palette():
+    assert CityscapesDataset.METAINFO['palette'] == get_palette('cityscapes')
+    assert PascalVOCDataset.METAINFO['palette'] == get_palette(
+        'voc') == get_palette('pascal_voc')
+    assert ADE20KDataset.METAINFO['palette'] == get_palette(
+        'ade') == get_palette('ade20k')
+    assert LoveDADataset.METAINFO['palette'] == get_palette('loveda')
+    assert PotsdamDataset.METAINFO['palette'] == get_palette('potsdam')
+    assert COCOStuffDataset.METAINFO['palette'] == get_palette('cocostuff')
+    assert iSAIDDataset.METAINFO['palette'] == get_palette('isaid')
+
+    with pytest.raises(ValueError):
+        get_palette('unsupported')
+
+
+def test_custom_dataset():
+
+    # with 'img_path' and 'seg_map_path' in data_prefix
+    train_dataset = CustomDataset(
+        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
+        data_prefix=dict(
+            img_path='imgs/',
+            seg_map_path='gts/',
+        ),
+        img_suffix='img.jpg',
+        seg_map_suffix='gt.png')
+    assert len(train_dataset) == 5
+
+    # with 'img_path' and 'seg_map_path' in data_prefix and ann_file
+    train_dataset = CustomDataset(
+        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
+        data_prefix=dict(
+            img_path='imgs/',
+            seg_map_path='gts/',
+        ),
+        img_suffix='img.jpg',
+        seg_map_suffix='gt.png',
+        ann_file='splits/train.txt')
+    assert len(train_dataset) == 4
+
+    # no data_root
+    train_dataset = CustomDataset(
+        data_prefix=dict(
+            img_path=osp.join(
+                osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
+            seg_map_path=osp.join(
+                osp.dirname(__file__), '../data/pseudo_dataset/gts')),
+        img_suffix='img.jpg',
+        seg_map_suffix='gt.png')
+    assert len(train_dataset) == 5
+
+    # with data_root but 'img_path' and 'seg_map_path' in data_prefix are
+    # abs path
+    train_dataset = CustomDataset(
+        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
+        data_prefix=dict(
+            img_path=osp.join(
+                osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
+            seg_map_path=osp.join(
+                osp.dirname(__file__), '../data/pseudo_dataset/gts')),
+        img_suffix='img.jpg',
+        seg_map_suffix='gt.png')
+    assert len(train_dataset) == 5
+
+    # test_mode=True
+    test_dataset = CustomDataset(
+        data_prefix=dict(
+            img_path=osp.join(
+                osp.dirname(__file__), '../data/pseudo_dataset/imgs')),
+        img_suffix='img.jpg',
+        test_mode=True,
+        metainfo=dict(classes=('pseudo_class', )))
+    assert len(test_dataset) == 5
+
+    # training data get
+    train_data = train_dataset[0]
+    assert isinstance(train_data, dict)
+    assert 'img_path' in train_data and osp.isfile(train_data['img_path'])
+    assert 'seg_map_path' in train_data and osp.isfile(
+        train_data['seg_map_path'])
+
+    # test data get
+    test_data = test_dataset[0]
+    assert isinstance(test_data, dict)
+    assert 'img_path' in train_data and osp.isfile(train_data['img_path'])
+    assert 'seg_map_path' in train_data and osp.isfile(
+        train_data['seg_map_path'])
+
+
+def test_ade():
+    test_dataset = ADE20KDataset(
+        pipeline=[],
+        data_prefix=dict(
+            img_path=osp.join(
+                osp.dirname(__file__), '../data/pseudo_dataset/imgs')))
+    assert len(test_dataset) == 5
+
+
+def test_cityscapes():
+    test_dataset = CityscapesDataset(
+        pipeline=[],
+        data_prefix=dict(
+            img_path=osp.join(
+                osp.dirname(__file__),
+                '../data/pseudo_cityscapes_dataset/leftImg8bit'),
+            seg_map_path=osp.join(
+                osp.dirname(__file__),
+                '../data/pseudo_cityscapes_dataset/gtFine')))
+    assert len(test_dataset) == 1
+
+
+def test_loveda():
+    test_dataset = LoveDADataset(
+        pipeline=[],
+        data_prefix=dict(
+            img_path=osp.join(
+                osp.dirname(__file__),
+                '../data/pseudo_loveda_dataset/img_dir'),
+            seg_map_path=osp.join(
+                osp.dirname(__file__),
+                '../data/pseudo_loveda_dataset/ann_dir')))
+    assert len(test_dataset) == 3
+
+
+def test_potsdam():
+    test_dataset = PotsdamDataset(
+        pipeline=[],
+        data_prefix=dict(
+            img_path=osp.join(
+                osp.dirname(__file__),
+                '../data/pseudo_potsdam_dataset/img_dir'),
+            seg_map_path=osp.join(
+                osp.dirname(__file__),
+                '../data/pseudo_potsdam_dataset/ann_dir')))
+    assert len(test_dataset) == 1
+
+
+def test_vaihingen():
+    test_dataset = ISPRSDataset(
+        pipeline=[],
+        data_prefix=dict(
+            img_path=osp.join(
+                osp.dirname(__file__),
+                '../data/pseudo_vaihingen_dataset/img_dir'),
+            seg_map_path=osp.join(
+                osp.dirname(__file__),
+                '../data/pseudo_vaihingen_dataset/ann_dir')))
+    assert len(test_dataset) == 1
+
+
+def test_isaid():
+    test_dataset = iSAIDDataset(
+        pipeline=[],
+        data_prefix=dict(
+            img_path=osp.join(
+                osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
+            seg_map_path=osp.join(
+                osp.dirname(__file__),
+                '../data/pseudo_isaid_dataset/ann_dir')))
+    assert len(test_dataset) == 2
+    test_dataset = iSAIDDataset(
+        data_prefix=dict(
+            img_path=osp.join(
+                osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
+            seg_map_path=osp.join(
+                osp.dirname(__file__),
+                '../data/pseudo_isaid_dataset/ann_dir')),
+        ann_file=osp.join(
+            osp.dirname(__file__),
+            '../data/pseudo_isaid_dataset/splits/train.txt'))
+    assert len(test_dataset) == 1
+
+
+@pytest.mark.parametrize('dataset, classes', [
+    ('ADE20KDataset', ('wall', 'building')),
+    ('CityscapesDataset', ('road', 'sidewalk')),
+    ('CustomDataset', ('bus', 'car')),
+    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
+])
+def test_custom_classes_override_default(dataset, classes):
+
+    dataset_class = DATASETS.get(dataset)
+    if isinstance(dataset_class, PascalVOCDataset):
+        tmp_file = tempfile.NamedTemporaryFile()
+        ann_file = f'{tmp_file.name}.txt'
+    else:
+        ann_file = MagicMock()
+
+    original_classes = dataset_class.METAINFO.get('classes', None)
+
+    # Test setting classes as a tuple
+    custom_dataset = dataset_class(
+        data_prefix=dict(img_path=MagicMock()),
+        ann_file=ann_file,
+        metainfo=dict(classes=classes),
+        test_mode=True,
+        lazy_init=True)
+
+    assert custom_dataset.metainfo['classes'] != original_classes
+    assert custom_dataset.metainfo['classes'] == classes
+    if not isinstance(custom_dataset, CustomDataset):
+        assert isinstance(custom_dataset.label_map, dict)
+
+    # Test setting classes as a list
+    custom_dataset = dataset_class(
+        data_prefix=dict(img_path=MagicMock()),
+        ann_file=ann_file,
+        metainfo=dict(classes=list(classes)),
+        test_mode=True,
+        lazy_init=True)
+
+    assert custom_dataset.metainfo['classes'] != original_classes
+    assert custom_dataset.metainfo['classes'] == list(classes)
+    if not isinstance(custom_dataset, CustomDataset):
+        assert isinstance(custom_dataset.label_map, dict)
+
+    # Test overriding not a subset
+    custom_dataset = dataset_class(
+        ann_file=ann_file,
+        data_prefix=dict(img_path=MagicMock()),
+        metainfo=dict(classes=[classes[0]]),
+        test_mode=True,
+        lazy_init=True)
+
+    assert custom_dataset.metainfo['classes'] != original_classes
+    assert custom_dataset.metainfo['classes'] == [classes[0]]
+    if not isinstance(custom_dataset, CustomDataset):
+        assert isinstance(custom_dataset.label_map, dict)
+
+    # Test default behavior
+    if dataset_class is CustomDataset:
+        with pytest.raises(AssertionError):
+            custom_dataset = dataset_class(
+                ann_file=ann_file,
+                data_prefix=dict(img_path=MagicMock()),
+                metainfo=None,
+                test_mode=True,
+                lazy_init=True)
+    else:
+        custom_dataset = dataset_class(
+            data_prefix=dict(img_path=MagicMock()),
+            ann_file=ann_file,
+            metainfo=None,
+            test_mode=True,
+            lazy_init=True)
+
+        assert custom_dataset.METAINFO['classes'] == original_classes
+        assert custom_dataset.label_map is None
+
+
+def test_custom_dataset_random_palette_is_generated():
+    dataset = CustomDataset(
+        pipeline=[],
+        data_prefix=dict(img_path=MagicMock()),
+        ann_file=MagicMock(),
+        metainfo=dict(classes=('bus', 'car')),
+        lazy_init=True,
+        test_mode=True)
+    assert len(dataset.metainfo['palette']) == 2
+    for class_color in dataset.metainfo['palette']:
+        assert len(class_color) == 3
+        assert all(x >= 0 and x <= 255 for x in class_color)
+
+
+def test_custom_dataset_custom_palette():
+    dataset = CustomDataset(
+        data_prefix=dict(img_path=MagicMock()),
+        ann_file=MagicMock(),
+        metainfo=dict(
+            classes=('bus', 'car'), palette=[[100, 100, 100], [200, 200,
+                                                               200]]),
+        lazy_init=True,
+        test_mode=True)
+    assert tuple(dataset.metainfo['palette']) == tuple([[100, 100, 100],
+                                                        [200, 200, 200]])
+    # test custom class and palette don't match
+    with pytest.raises(ValueError):
+        dataset = CustomDataset(
+            data_prefix=dict(img_path=MagicMock()),
+            ann_file=MagicMock(),
+            metainfo=dict(classes=('bus', 'car'), palette=[[200, 200, 200]]),
+            lazy_init=True)

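One behavior worth calling out from `test_custom_classes_override_default`: when the supplied classes are a strict subset of a dataset's full class list, the dataset builds a `label_map` that renumbers the kept classes and routes everything else to the ignore index. A sketch of that remapping, assuming 255 as the ignore index (names here are illustrative, not the base-class code):

    full_classes = ('road', 'sidewalk', 'building')
    subset = ('road', 'building')

    label_map = {
        old_id: subset.index(name) if name in subset else 255
        for old_id, name in enumerate(full_classes)
    }
    assert label_map == {0: 0, 1: 255, 2: 1}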