Merge branch 'zhengmiao/refactory-dataset' into 'refactor_dev'

[Refactory] Dataset refactory

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!15
pull/1801/head
zhengmiao 2022-05-26 09:13:40 +00:00
commit d64f941fb3
17 changed files with 835 additions and 1086 deletions

View File

@ -1,10 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import numpy as np
from PIL import Image
from mmseg.registry import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset
@ -18,33 +12,36 @@ class ADE20KDataset(CustomDataset):
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
'.png'. '.png'.
""" """
CLASSES = ( METAINFO = dict(
'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', classes=('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road',
'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk',
'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'person', 'earth', 'door', 'table', 'mountain', 'plant',
'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'curtain', 'chair', 'car', 'water', 'painting', 'sofa',
'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair',
'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp',
'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 'skyscraper', 'fireplace', 'refrigerator', 'grandstand',
'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow',
'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'screen door', 'stairway', 'river', 'bridge', 'bookcase',
'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill',
'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'bench', 'countertop', 'stove', 'palm', 'kitchen island',
'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine',
'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', 'chandelier', 'awning', 'streetlight', 'booth',
'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 'television receiver', 'airplane', 'dirt track', 'apparel',
'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle',
'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain',
'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'conveyer belt', 'canopy', 'washer', 'plaything',
'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall',
'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food',
'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal',
'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket',
'clock', 'flag') 'sculpture', 'hood', 'sconce', 'vase', 'traffic light',
'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate',
PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
'clock', 'flag'),
palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
@ -81,87 +78,11 @@ class ADE20KDataset(CustomDataset):
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
[102, 255, 0], [92, 0, 255]] [102, 255, 0], [92, 0, 255]])
def __init__(self, **kwargs): def __init__(self, **kwargs) -> None:
super(ADE20KDataset, self).__init__( super().__init__(
img_suffix='.jpg', img_suffix='.jpg',
seg_map_suffix='.png', seg_map_suffix='.png',
reduce_zero_label=True, reduce_zero_label=True,
**kwargs) **kwargs)
def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
"""Write the segmentation results to images.
Args:
results (list[ndarray]): Testing results of the
dataset.
imgfile_prefix (str): The filename prefix of the png files.
If the prefix is "somepath/xxx",
the png files will be named "somepath/xxx.png".
to_label_id (bool): whether convert output to label_id for
submission.
indices (list[int], optional): Indices of input results, if not
set, all the indices of the dataset will be used.
Default: None.
Returns:
list[str: str]: result txt files which contains corresponding
semantic segmentation images.
"""
if indices is None:
indices = list(range(len(self)))
mmcv.mkdir_or_exist(imgfile_prefix)
result_files = []
for result, idx in zip(results, indices):
filename = self.img_infos[idx]['filename']
basename = osp.splitext(osp.basename(filename))[0]
png_filename = osp.join(imgfile_prefix, f'{basename}.png')
# The index range of official requirement is from 0 to 150.
# But the index range of output is from 0 to 149.
# That is because we set reduce_zero_label=True.
result = result + 1
output = Image.fromarray(result.astype(np.uint8))
output.save(png_filename)
result_files.append(png_filename)
return result_files
def format_results(self,
results,
imgfile_prefix,
to_label_id=True,
indices=None):
"""Format the results into dir (standard format for ade20k evaluation).
Args:
results (list): Testing results of the dataset.
imgfile_prefix (str | None): The prefix of images files. It
includes the file path and the prefix of filename, e.g.,
"a/b/prefix".
to_label_id (bool): whether convert output to label_id for
submission. Default: False
indices (list[int], optional): Indices of input results, if not
set, all the indices of the dataset will be used.
Default: None.
Returns:
tuple: (result_files, tmp_dir), result_files is a list containing
the image paths, tmp_dir is the temporal directory created
for saving json/png files when img_prefix is not specified.
"""
if indices is None:
indices = list(range(len(self)))
assert isinstance(results, list), 'results must be a list.'
assert isinstance(indices, list), 'indices must be a list.'
result_files = self.results2img(results, imgfile_prefix, to_label_id,
indices)
return result_files

View File

@ -13,15 +13,14 @@ class ChaseDB1Dataset(CustomDataset):
The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
'_1stHO.png'. '_1stHO.png'.
""" """
METAFILE = dict(
classes=('background', 'vessel'),
palette=[[120, 120, 120], [6, 230, 230]])
CLASSES = ('background', 'vessel') def __init__(self, **kwargs) -> None:
super().__init__(
PALETTE = [[120, 120, 120], [6, 230, 230]]
def __init__(self, **kwargs):
super(ChaseDB1Dataset, self).__init__(
img_suffix='.png', img_suffix='.png',
seg_map_suffix='_1stHO.png', seg_map_suffix='_1stHO.png',
reduce_zero_label=False, reduce_zero_label=False,
**kwargs) **kwargs)
assert self.file_client.exists(self.img_dir) assert self.file_client.exists(self.data_prefix['img_list'])

View File

@ -1,11 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import numpy as np
from mmcv.utils import print_log
from PIL import Image
from mmseg.registry import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset
@ -17,198 +10,21 @@ class CityscapesDataset(CustomDataset):
The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is
fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset. fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
""" """
METAINFO = dict(
CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 'traffic light', 'traffic sign', 'vegetation', 'terrain',
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
'bicycle') 'motorcycle', 'bicycle'),
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153], [153, 153, 153], [250, 170,
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], 30], [220, 220, 0],
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], [107, 142, 35], [152, 251, 152], [70, 130, 180],
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
[0, 80, 100], [0, 0, 230], [119, 11, 32]] [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]])
def __init__(self, def __init__(self,
img_suffix='_leftImg8bit.png', img_suffix='_leftImg8bit.png',
seg_map_suffix='_gtFine_labelTrainIds.png', seg_map_suffix='_gtFine_labelTrainIds.png',
**kwargs): **kwargs) -> None:
super(CityscapesDataset, self).__init__( super().__init__(
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
@staticmethod
def _convert_to_label_id(result):
"""Convert trainId to id for cityscapes."""
if isinstance(result, str):
result = np.load(result)
import cityscapesscripts.helpers.labels as CSLabels
result_copy = result.copy()
for trainId, label in CSLabels.trainId2label.items():
result_copy[result == trainId] = label.id
return result_copy
def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
"""Write the segmentation results to images.
Args:
results (list[ndarray]): Testing results of the
dataset.
imgfile_prefix (str): The filename prefix of the png files.
If the prefix is "somepath/xxx",
the png files will be named "somepath/xxx.png".
to_label_id (bool): whether convert output to label_id for
submission.
indices (list[int], optional): Indices of input results,
if not set, all the indices of the dataset will be used.
Default: None.
Returns:
list[str: str]: result txt files which contains corresponding
semantic segmentation images.
"""
if indices is None:
indices = list(range(len(self)))
mmcv.mkdir_or_exist(imgfile_prefix)
result_files = []
for result, idx in zip(results, indices):
if to_label_id:
result = self._convert_to_label_id(result)
filename = self.img_infos[idx]['filename']
basename = osp.splitext(osp.basename(filename))[0]
png_filename = osp.join(imgfile_prefix, f'{basename}.png')
output = Image.fromarray(result.astype(np.uint8)).convert('P')
import cityscapesscripts.helpers.labels as CSLabels
palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
for label_id, label in CSLabels.id2label.items():
palette[label_id] = label.color
output.putpalette(palette)
output.save(png_filename)
result_files.append(png_filename)
return result_files
def format_results(self,
results,
imgfile_prefix,
to_label_id=True,
indices=None):
"""Format the results into dir (standard format for Cityscapes
evaluation).
Args:
results (list): Testing results of the dataset.
imgfile_prefix (str): The prefix of images files. It
includes the file path and the prefix of filename, e.g.,
"a/b/prefix".
to_label_id (bool): whether convert output to label_id for
submission. Default: False
indices (list[int], optional): Indices of input results,
if not set, all the indices of the dataset will be used.
Default: None.
Returns:
tuple: (result_files, tmp_dir), result_files is a list containing
the image paths, tmp_dir is the temporal directory created
for saving json/png files when img_prefix is not specified.
"""
if indices is None:
indices = list(range(len(self)))
assert isinstance(results, list), 'results must be a list.'
assert isinstance(indices, list), 'indices must be a list.'
result_files = self.results2img(results, imgfile_prefix, to_label_id,
indices)
return result_files
def evaluate(self,
results,
metric='mIoU',
logger=None,
imgfile_prefix=None):
"""Evaluation in Cityscapes/default protocol.
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.
imgfile_prefix (str | None): The prefix of output image file,
for cityscapes evaluation only. It includes the file path and
the prefix of filename, e.g., "a/b/prefix".
If results are evaluated with cityscapes protocol, it would be
the prefix of output png files. The output files would be
png images under folder "a/b/prefix/xxx.png", where "xxx" is
the image name of cityscapes. If not specified, a temp file
will be created for evaluation.
Default: None.
Returns:
dict[str, float]: Cityscapes/default metrics.
"""
eval_results = dict()
metrics = metric.copy() if isinstance(metric, list) else [metric]
if 'cityscapes' in metrics:
eval_results.update(
self._evaluate_cityscapes(results, logger, imgfile_prefix))
metrics.remove('cityscapes')
if len(metrics) > 0:
eval_results.update(
super(CityscapesDataset,
self).evaluate(results, metrics, logger))
return eval_results
def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
"""Evaluation in Cityscapes protocol.
Args:
results (list): Testing results of the dataset.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
imgfile_prefix (str | None): The prefix of output image file
Returns:
dict[str: float]: Cityscapes evaluation results.
"""
try:
import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
except ImportError:
raise ImportError('Please run "pip install cityscapesscripts" to '
'install cityscapesscripts first.')
msg = 'Evaluating in Cityscapes style'
if logger is None:
msg = '\n' + msg
print_log(msg, logger=logger)
result_dir = imgfile_prefix
eval_results = dict()
print_log(f'Evaluating results under {result_dir} ...', logger=logger)
CSEval.args.evalInstLevelScore = True
CSEval.args.predictionPath = osp.abspath(result_dir)
CSEval.args.evalPixelAccuracy = True
CSEval.args.JSONOutput = False
seg_map_list = []
pred_list = []
# when evaluating with official cityscapesscripts,
# **_gtFine_labelIds.png is used
for seg_map in mmcv.scandir(
self.ann_dir, 'gtFine_labelIds.png', recursive=True):
seg_map_list.append(osp.join(self.ann_dir, seg_map))
pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
eval_results.update(
CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
return eval_results

View File

@ -14,38 +14,40 @@ class COCOStuffDataset(CustomDataset):
10k and 164k versions, respectively. The ``img_suffix`` is fixed to '.jpg', 10k and 164k versions, respectively. The ``img_suffix`` is fixed to '.jpg',
and ``seg_map_suffix`` is fixed to '.png'. and ``seg_map_suffix`` is fixed to '.png'.
""" """
CLASSES = ( METAINFO = dict(
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', classes=(
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',
'flower', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',
'gravel', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',
'metal', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net',
'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 'platform', 'paper', 'pavement', 'pillow', 'plant-other', 'plastic',
'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof', 'platform', 'playingfield', 'railing', 'railroad', 'river', 'road',
'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper', 'rock', 'roof', 'rug', 'salad', 'sand', 'sea', 'shelf',
'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other', 'sky-other', 'skyscraper', 'snow', 'solid-other', 'stairs',
'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable', 'stone', 'straw', 'structural-other', 'table', 'tent',
'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel', 'textile-other', 'towel', 'tree', 'vegetable', 'wall-brick',
'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', 'wall-concrete', 'wall-other', 'wall-panel', 'wall-stone',
'window-blind', 'window-other', 'wood') 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
'window-blind', 'window-other', 'wood'),
PALETTE = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192], palette=[[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
[0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64], [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
[0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224], [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
[0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192], [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
@ -87,8 +89,8 @@ class COCOStuffDataset(CustomDataset):
[0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160], [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],
[64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192], [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],
[192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128], [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],
[64, 192, 96], [64, 160, 64], [64, 64, 0]] [64, 192, 96], [64, 160, 64], [64, 64, 0]])
def __init__(self, **kwargs): def __init__(self, **kwargs) -> None:
super(COCOStuffDataset, self).__init__( super().__init__(
img_suffix='.jpg', seg_map_suffix='_labelTrainIds.png', **kwargs) img_suffix='.jpg', seg_map_suffix='_labelTrainIds.png', **kwargs)

View File

@ -1,22 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp import os.path as osp
import warnings from typing import Callable, Dict, List, Optional, Sequence, Union
from collections import OrderedDict
import mmcv import mmcv
import numpy as np import numpy as np
from mmcv.utils import print_log from mmengine.dataset import BaseDataset, Compose
from prettytable import PrettyTable
from torch.utils.data import Dataset
from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
from mmseg.registry import DATASETS from mmseg.registry import DATASETS
from mmseg.utils import get_root_logger
from .pipelines import Compose, LoadAnnotations
@DATASETS.register_module() @DATASETS.register_module()
class CustomDataset(Dataset): class CustomDataset(BaseDataset):
"""Custom dataset for semantic segmentation. An example of file structure """Custom dataset for semantic segmentation. An example of file structure
is as followed. is as followed.
@ -46,330 +41,163 @@ class CustomDataset(Dataset):
Args: Args:
pipeline (list[dict]): Processing pipeline ann_file (str): Annotation file path. Defaults to ''.
img_dir (str): Path to image directory metainfo (dict, optional): Meta information for dataset, such as
specify classes to load. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix`` and
``ann_file``. Defaults to None.
data_prefix (dict, optional): Prefix for training data. Defaults to
dict(img_path=None, seg_path=None).
img_suffix (str): Suffix of images. Default: '.jpg' img_suffix (str): Suffix of images. Default: '.jpg'
ann_dir (str, optional): Path to annotation directory. Default: None
seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
split (str, optional): Split txt file. If split is specified, only filter_cfg (dict, optional): Config for filter data. Defaults to None.
file with suffix in the splits will be loaded. Otherwise, all indices (int or Sequence[int], optional): Support using first few
images in img_dir/ann_dir will be loaded. Default: None data in annotation file to facilitate training/testing on a smaller
data_root (str, optional): Data root for img_dir/ann_dir. Default: dataset. Defaults to None which means using all ``data_infos``.
None. serialize_data (bool, optional): Whether to hold memory using
test_mode (bool): If test_mode=True, gt wouldn't be loaded. serialized objects, when enabled, data loader workers can use
shared RAM from master process instead of making a copy. Defaults
to True.
pipeline (list, optional): Processing pipeline. Defaults to [].
test_mode (bool, optional): ``test_mode=True`` means in test phase.
Defaults to False.
lazy_init (bool, optional): Whether to load annotation during
instantiation. In some cases, such as visualization, only the meta
information of the dataset is needed, which is not necessary to
load annotation file. ``Basedataset`` can skip load annotations to
save time by set ``lazy_init=False``. Defaults to False.
max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
None img. The maximum extra number of cycles to get a valid
image. Defaults to 1000.
ignore_index (int): The label index to be ignored. Default: 255 ignore_index (int): The label index to be ignored. Default: 255
reduce_zero_label (bool): Whether to mark label zero as ignored. reduce_zero_label (bool): Whether to mark label zero as ignored.
Default: False Default: False
classes (str | Sequence[str], optional): Specify classes to load.
If is None, ``cls.CLASSES`` will be used. Default: None.
palette (Sequence[Sequence[int]]] | np.ndarray | None):
The palette of segmentation map. If None is given, and
self.PALETTE is None, random palette will be generated.
Default: None
gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to
load gt for evaluation, load from disk by default. Default: None.
file_client_args (dict): Arguments to instantiate a FileClient. file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details. See :class:`mmcv.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``. Defaults to ``dict(backend='disk')``.
""" """
METAINFO: dict = dict()
CLASSES = None def __init__(
self,
PALETTE = None ann_file: str = '',
def __init__(self,
pipeline,
img_dir,
img_suffix='.jpg', img_suffix='.jpg',
ann_dir=None,
seg_map_suffix='.png', seg_map_suffix='.png',
split=None, metainfo: Optional[dict] = None,
data_root=None, data_root: Optional[str] = None,
test_mode=False, data_prefix: dict = dict(img_path=None, seg_map_path=None),
ignore_index=255, filter_cfg: Optional[dict] = None,
reduce_zero_label=False, indices: Optional[Union[int, Sequence[int]]] = None,
classes=None, serialize_data: bool = True,
palette=None, pipeline: List[Union[dict, Callable]] = [],
gt_seg_map_loader_cfg=None, test_mode: bool = False,
file_client_args=dict(backend='disk')): lazy_init: bool = False,
self.pipeline = Compose(pipeline) max_refetch: int = 1000,
self.img_dir = img_dir ignore_index: int = 255,
reduce_zero_label: bool = True,
file_client_args: dict = dict(backend='disk')
) -> None:
self.img_suffix = img_suffix self.img_suffix = img_suffix
self.ann_dir = ann_dir
self.seg_map_suffix = seg_map_suffix self.seg_map_suffix = seg_map_suffix
self.split = split
self.data_root = data_root
self.test_mode = test_mode
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.reduce_zero_label = reduce_zero_label self.reduce_zero_label = reduce_zero_label
self.label_map = None
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
classes, palette)
self.gt_seg_map_loader = LoadAnnotations(
) if gt_seg_map_loader_cfg is None else LoadAnnotations(
**gt_seg_map_loader_cfg)
self.file_client_args = file_client_args self.file_client_args = file_client_args
self.file_client = mmcv.FileClient.infer_client(self.file_client_args) self.file_client = mmcv.FileClient.infer_client(self.file_client_args)
self.data_root = data_root
self.data_prefix = copy.copy(data_prefix)
self.ann_file = ann_file
self.filter_cfg = copy.deepcopy(filter_cfg)
self._indices = indices
self.serialize_data = serialize_data
self.test_mode = test_mode
self.max_refetch = max_refetch
self.data_list: List[dict] = []
self.data_bytes: np.ndarray
# Set meta information.
self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))
# Get label map for custom classes
new_classes = self._metainfo.get('classes', None)
self.label_map = self.get_label_map(new_classes)
# Update palette based on label map or generate palette
# if it is not defined
updated_palette = self._update_palette()
self._metainfo.update(dict(palette=updated_palette))
if test_mode: if test_mode:
assert self.CLASSES is not None, \ assert self._metainfo.get('classes') is not None, \
'`cls.CLASSES` or `classes` should be specified when testing' 'dataset metainfo `classes` should be specified when testing'
# join paths if data_root is specified # Join paths.
if self.data_root is not None: if self.data_root is not None:
if not osp.isabs(self.img_dir): self._join_prefix()
self.img_dir = osp.join(self.data_root, self.img_dir)
if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
self.ann_dir = osp.join(self.data_root, self.ann_dir)
if not (self.split is None or osp.isabs(self.split)):
self.split = osp.join(self.data_root, self.split)
# load annotations # Build pipeline.
self.img_infos = self.load_annotations(self.img_dir, self.img_suffix, self.pipeline = Compose(pipeline)
self.ann_dir, # Full initialize the dataset.
self.seg_map_suffix, self.split) if not lazy_init:
self.full_init()
def __len__(self): @classmethod
"""Total number of samples of data.""" def get_label_map(cls,
return len(self.img_infos) new_classes: Optional[Sequence] = None
) -> Union[Dict, None]:
"""Require label mapping.
def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix, The ``label_map`` is a dictionary, its keys are the old label ids and
split): its values are the new label ids, and is used for changing pixel
"""Load annotation from directory. labels in load_annotations. If and only if old classes in cls.METAINFO
is not equal to new classes in self._metainfo and nether of them is not
None, `label_map` is not None.
Args: Args:
img_dir (str): Path to image directory new_classes (list, tuple, optional): The new classes name from
img_suffix (str): Suffix of images. metainfo. Default to None.
ann_dir (str|None): Path to annotation directory.
seg_map_suffix (str|None): Suffix of segmentation maps.
split (str|None): Split txt file. If split is specified, only file
with suffix in the splits will be loaded. Otherwise, all images
in img_dir/ann_dir will be loaded. Default: None
Returns: Returns:
list[dict]: All image info of dataset. dict, optional: The mapping from old classes in cls.METAINFO to
new classes in self._metainfo
""" """
old_classes = cls.METAINFO.get('classes', None)
if (new_classes is not None and old_classes is not None
and list(new_classes) != list(old_classes)):
img_infos = [] label_map = {}
if split is not None: if not set(new_classes).issubset(cls.METAINFO['classes']):
lines = mmcv.list_from_file( raise ValueError(
split, file_client_args=self.file_client_args) f'new classes {new_classes} is not a '
for line in lines: f'subset of classes {old_classes} in METAINFO.')
img_name = line.strip() for i, c in enumerate(old_classes):
img_info = dict(filename=img_name + img_suffix) if c not in new_classes:
if ann_dir is not None: label_map[i] = -1
seg_map = img_name + seg_map_suffix
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
else: else:
for img in self.file_client.list_dir_or_file( label_map[i] = new_classes.index(c)
dir_path=img_dir, return label_map
list_dir=False,
suffix=img_suffix,
recursive=True):
img_info = dict(filename=img)
if ann_dir is not None:
seg_map = img.replace(img_suffix, seg_map_suffix)
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
img_infos = sorted(img_infos, key=lambda x: x['filename'])
print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
return img_infos
def get_ann_info(self, idx):
"""Get annotation by index.
Args:
idx (int): Index of data.
Returns:
dict: Annotation info of specified index.
"""
return self.img_infos[idx]['ann']
def pre_pipeline(self, results):
"""Prepare results dict for pipeline."""
results['seg_fields'] = []
results['img_prefix'] = self.img_dir
results['seg_prefix'] = self.ann_dir
if self.custom_classes:
results['label_map'] = self.label_map
def __getitem__(self, idx):
"""Get training/test data after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Training/test data (with annotation if `test_mode` is set
False).
"""
if self.test_mode:
return self.prepare_test_img(idx)
else: else:
return self.prepare_train_img(idx) return None
def prepare_train_img(self, idx): def _update_palette(self) -> list:
"""Get training data and annotations after pipeline. """Update palette after loading metainfo.
Args: If length of palette is equal to classes, just return the palette.
idx (int): Index of data. If palette is not defined, it will randomly generate a palette.
If classes is updated by customer, it will return the subset of
palette.
Returns: Returns:
dict: Training data and annotation after pipeline with new keys Sequence: Palette for current dataset.
introduced by pipeline.
""" """
palette = self._metainfo.get('palette', [])
classes = self._metainfo.get('classes', [])
# palette does match classes
if len(palette) == len(classes):
return palette
img_info = self.img_infos[idx] if len(palette) == 0:
ann_info = self.get_ann_info(idx)
results = dict(img_info=img_info, ann_info=ann_info)
self.pre_pipeline(results)
return self.pipeline(results)
def prepare_test_img(self, idx):
"""Get testing data after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Testing data after pipeline with new keys introduced by
pipeline.
"""
img_info = self.img_infos[idx]
results = dict(img_info=img_info)
self.pre_pipeline(results)
return self.pipeline(results)
def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
"""Place holder to format result to dataset specific output."""
raise NotImplementedError
def get_gt_seg_map_by_idx(self, index):
"""Get one ground truth segmentation map for evaluation."""
ann_info = self.get_ann_info(index)
results = dict(ann_info=ann_info)
self.pre_pipeline(results)
self.gt_seg_map_loader(results)
return results['gt_semantic_seg']
def get_gt_seg_maps(self, efficient_test=None):
"""Get ground truth segmentation maps for evaluation."""
if efficient_test is not None:
warnings.warn(
'DeprecationWarning: ``efficient_test`` has been deprecated '
'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
'friendly by default. ')
for idx in range(len(self)):
ann_info = self.get_ann_info(idx)
results = dict(ann_info=ann_info)
self.pre_pipeline(results)
self.gt_seg_map_loader(results)
yield results['gt_semantic_seg']
def pre_eval(self, preds, indices):
"""Collect eval result from each iteration.
Args:
preds (list[torch.Tensor] | torch.Tensor): the segmentation logit
after argmax, shape (N, H, W).
indices (list[int] | int): the prediction related ground truth
indices.
Returns:
list[torch.Tensor]: (area_intersect, area_union, area_prediction,
area_ground_truth).
"""
# In order to compat with batch inference
if not isinstance(indices, list):
indices = [indices]
if not isinstance(preds, list):
preds = [preds]
pre_eval_results = []
for pred, index in zip(preds, indices):
seg_map = self.get_gt_seg_map_by_idx(index)
pre_eval_results.append(
intersect_and_union(
pred,
seg_map,
len(self.CLASSES),
self.ignore_index,
# as the labels has been converted when dataset initialized
# in `get_palette_for_custom_classes ` this `label_map`
# should be `dict()`, see
# https://github.com/open-mmlab/mmsegmentation/issues/1415
# for more ditails
label_map=dict(),
reduce_zero_label=self.reduce_zero_label))
return pre_eval_results
def get_classes_and_palette(self, classes=None, palette=None):
"""Get class names of current dataset.
Args:
classes (Sequence[str] | str | None): If classes is None, use
default CLASSES defined by builtin dataset. If classes is a
string, take it as a file name. The file contains the name of
classes where each line contains one class name. If classes is
a tuple or list, override the CLASSES defined by the dataset.
palette (Sequence[Sequence[int]]] | np.ndarray | None):
The palette of segmentation map. If None is given, random
palette will be generated. Default: None
"""
if classes is None:
self.custom_classes = False
return self.CLASSES, self.PALETTE
self.custom_classes = True
if isinstance(classes, str):
# take it as a file path
class_names = mmcv.list_from_file(classes)
elif isinstance(classes, (tuple, list)):
class_names = classes
else:
raise ValueError(f'Unsupported type {type(classes)} of classes.')
if self.CLASSES:
if not set(class_names).issubset(self.CLASSES):
raise ValueError('classes is not a subset of CLASSES.')
# dictionary, its keys are the old label ids and its values
# are the new label ids.
# used for changing pixel labels in load_annotations.
self.label_map = {}
for i, c in enumerate(self.CLASSES):
if c not in class_names:
self.label_map[i] = -1
else:
self.label_map[i] = class_names.index(c)
palette = self.get_palette_for_custom_classes(class_names, palette)
return class_names, palette
def get_palette_for_custom_classes(self, class_names, palette=None):
if self.label_map is not None:
# return subset of palette
palette = []
for old_id, new_id in sorted(
self.label_map.items(), key=lambda x: x[1]):
if new_id != -1:
palette.append(self.PALETTE[old_id])
palette = type(self.PALETTE)(palette)
elif palette is None:
if self.PALETTE is None:
# Get random state before set seed, and restore # Get random state before set seed, and restore
# random state later. # random state later.
# It will prevent loss of randomness, as the palette # It will prevent loss of randomness, as the palette
@ -378,110 +206,55 @@ class CustomDataset(Dataset):
state = np.random.get_state() state = np.random.get_state()
np.random.seed(42) np.random.seed(42)
# random palette # random palette
palette = np.random.randint(0, 255, size=(len(class_names), 3)) new_palette = np.random.randint(
0, 255, size=(len(classes), 3)).tolist()
np.random.set_state(state) np.random.set_state(state)
elif len(palette) >= len(classes) and self.label_map is not None:
new_palette = []
# return subset of palette
for old_id, new_id in sorted(
self.label_map.items(), key=lambda x: x[1]):
if new_id != -1:
new_palette.append(palette[old_id])
new_palette = type(palette)(new_palette)
else: else:
palette = self.PALETTE raise ValueError('palette does not match classes '
f'as metainfo is {self._metainfo}.')
return new_palette
return palette def load_data_list(self) -> List[dict]:
"""Load annotation from directory or annotation file.
def evaluate(self,
results,
metric='mIoU',
logger=None,
gt_seg_maps=None,
**kwargs):
"""Evaluate the dataset.
Args:
results (list[tuple[torch.Tensor]] | list[str]): per image pre_eval
results or predict segmentation map for computing evaluation
metric.
metric (str | list[str]): Metrics to be evaluated. 'mIoU',
'mDice' and 'mFscore' are supported.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.
gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
used in ConcatDataset
Returns: Returns:
dict[str, float]: Default metrics. list[dict]: All data info of dataset.
""" """
if isinstance(metric, str): data_list = []
metric = [metric] ann_dir = self.data_prefix.get('seg_map_path', None)
allowed_metrics = ['mIoU', 'mDice', 'mFscore'] if osp.isfile(self.ann_file):
if not set(metric).issubset(set(allowed_metrics)): lines = mmcv.list_from_file(
raise KeyError('metric {} is not supported'.format(metric)) self.ann_file, file_client_args=self.file_client_args)
for line in lines:
eval_results = {} img_name = line.strip()
# test a list of files data_info = dict(img_path=img_name + self.img_suffix)
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of( if ann_dir is not None:
results, str): seg_map = img_name + self.seg_map_suffix
if gt_seg_maps is None: data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
gt_seg_maps = self.get_gt_seg_maps() data_info['label_map'] = self.label_map
num_classes = len(self.CLASSES) data_info['seg_field'] = []
ret_metrics = eval_metrics( data_list.append(data_info)
results,
gt_seg_maps,
num_classes,
self.ignore_index,
metric,
label_map=dict(),
reduce_zero_label=self.reduce_zero_label)
# test a list of pre_eval_results
else: else:
ret_metrics = pre_eval_to_metrics(results, metric) img_dir = self.data_prefix['img_path']
for img in self.file_client.list_dir_or_file(
# Because dataset.CLASSES is required for per-eval. dir_path=img_dir,
if self.CLASSES is None: list_dir=False,
class_names = tuple(range(num_classes)) suffix=self.img_suffix,
else: recursive=True):
class_names = self.CLASSES data_info = dict(img_path=osp.join(img_dir, img))
if ann_dir is not None:
# summary table seg_map = img.replace(self.img_suffix, self.seg_map_suffix)
ret_metrics_summary = OrderedDict({ data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) data_info['label_map'] = self.label_map
for ret_metric, ret_metric_value in ret_metrics.items() data_info['seg_field'] = []
}) data_list.append(data_info)
data_list = sorted(data_list, key=lambda x: x['img_path'])
# each class table return data_list
ret_metrics.pop('aAcc', None)
ret_metrics_class = OrderedDict({
ret_metric: np.round(ret_metric_value * 100, 2)
for ret_metric, ret_metric_value in ret_metrics.items()
})
ret_metrics_class.update({'Class': class_names})
ret_metrics_class.move_to_end('Class', last=False)
# for logger
class_table_data = PrettyTable()
for key, val in ret_metrics_class.items():
class_table_data.add_column(key, val)
summary_table_data = PrettyTable()
for key, val in ret_metrics_summary.items():
if key == 'aAcc':
summary_table_data.add_column(key, [val])
else:
summary_table_data.add_column('m' + key, [val])
print_log('per class results:', logger)
print_log('\n' + class_table_data.get_string(), logger=logger)
print_log('Summary:', logger)
print_log('\n' + summary_table_data.get_string(), logger=logger)
# each metric dict
for key, value in ret_metrics_summary.items():
if key == 'aAcc':
eval_results[key] = value / 100.0
else:
eval_results['m' + key] = value / 100.0
ret_metrics_class.pop('Class', None)
for key, value in ret_metrics_class.items():
eval_results.update({
key + '.' + str(name): value[idx] / 100.0
for idx, name in enumerate(class_names)
})
return eval_results

View File

@ -7,7 +7,7 @@ from .cityscapes import CityscapesDataset
class DarkZurichDataset(CityscapesDataset): class DarkZurichDataset(CityscapesDataset):
"""DarkZurichDataset dataset.""" """DarkZurichDataset dataset."""
def __init__(self, **kwargs): def __init__(self, **kwargs) -> None:
super().__init__( super().__init__(
img_suffix='_rgb_anon.png', img_suffix='_rgb_anon.png',
seg_map_suffix='_gt_labelTrainIds.png', seg_map_suffix='_gt_labelTrainIds.png',

View File

@ -13,15 +13,14 @@ class DRIVEDataset(CustomDataset):
``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
'_manual1.png'. '_manual1.png'.
""" """
METAINFO = dict(
classes=('background', 'vessel'),
palette=[[120, 120, 120], [6, 230, 230]])
CLASSES = ('background', 'vessel') def __init__(self, **kwargs) -> None:
super().__init__(
PALETTE = [[120, 120, 120], [6, 230, 230]]
def __init__(self, **kwargs):
super(DRIVEDataset, self).__init__(
img_suffix='.png', img_suffix='.png',
seg_map_suffix='_manual1.png', seg_map_suffix='_manual1.png',
reduce_zero_label=False, reduce_zero_label=False,
**kwargs) **kwargs)
assert self.file_client.exists(self.img_dir) assert self.file_client.exists(self.data_prefix['img_path'])

View File

@ -13,15 +13,14 @@ class HRFDataset(CustomDataset):
``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
'.png'. '.png'.
""" """
METAINFO = dict(
classes=('background', 'vessel'),
palette=[[120, 120, 120], [6, 230, 230]])
CLASSES = ('background', 'vessel') def __init__(self, **kwargs) -> None:
super().__init__(
PALETTE = [[120, 120, 120], [6, 230, 230]]
def __init__(self, **kwargs):
super(HRFDataset, self).__init__(
img_suffix='.png', img_suffix='.png',
seg_map_suffix='.png', seg_map_suffix='.png',
reduce_zero_label=False, reduce_zero_label=False,
**kwargs) **kwargs)
assert self.file_client.exists(self.img_dir) assert self.file_client.exists(self.data_prefix['img_path'])

View File

@ -1,10 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import mmcv
from mmcv.utils import print_log
from mmseg.registry import DATASETS from mmseg.registry import DATASETS
from ..utils import get_root_logger
from .custom import CustomDataset from .custom import CustomDataset
@ -17,66 +12,21 @@ class iSAIDDataset(CustomDataset):
'_manual1.png'. '_manual1.png'.
""" """
CLASSES = ('background', 'ship', 'store_tank', 'baseball_diamond', METAINFO = dict(
classes=('background', 'ship', 'store_tank', 'baseball_diamond',
'tennis_court', 'basketball_court', 'Ground_Track_Field', 'tennis_court', 'basketball_court', 'Ground_Track_Field',
'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter', 'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter',
'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane', 'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane',
'Harbor') 'Harbor'),
palette=[[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
PALETTE = [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
[0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127], [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
[0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127], [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127],
[0, 127, 191], [0, 127, 255], [0, 100, 155]] [0, 127, 191], [0, 127, 255], [0, 100, 155]])
def __init__(self, **kwargs): def __init__(self, **kwargs) -> None:
super(iSAIDDataset, self).__init__( super().__init__(
img_suffix='.png', img_suffix='.png',
seg_map_suffix='.png', seg_map_suffix='_instance_color_RGB.png',
ignore_index=255, ignore_index=255,
**kwargs) **kwargs)
assert self.file_client.exists(self.img_dir) assert self.file_client.exists(self.data_prefix['img_path'])
def load_annotations(self,
img_dir,
img_suffix,
ann_dir,
seg_map_suffix=None,
split=None):
"""Load annotation from directory.
Args:
img_dir (str): Path to image directory
img_suffix (str): Suffix of images.
ann_dir (str|None): Path to annotation directory.
seg_map_suffix (str|None): Suffix of segmentation maps.
split (str|None): Split txt file. If split is specified, only file
with suffix in the splits will be loaded. Otherwise, all images
in img_dir/ann_dir will be loaded. Default: None
Returns:
list[dict]: All image info of dataset.
"""
img_infos = []
if split is not None:
with open(split) as f:
for line in f:
name = line.strip()
img_info = dict(filename=name + img_suffix)
if ann_dir is not None:
ann_name = name + '_instance_color_RGB'
seg_map = ann_name + seg_map_suffix
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
else:
for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
img_info = dict(filename=img)
if ann_dir is not None:
seg_img = img
seg_map = seg_img.replace(
img_suffix, '_instance_color_RGB' + seg_map_suffix)
img_info['ann'] = dict(seg_map=seg_map)
img_infos.append(img_info)
print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
return img_infos

View File

@ -11,14 +11,14 @@ class ISPRSDataset(CustomDataset):
``reduce_zero_label`` should be set to True. The ``img_suffix`` and ``reduce_zero_label`` should be set to True. The ``img_suffix`` and
``seg_map_suffix`` are both fixed to '.png'. ``seg_map_suffix`` are both fixed to '.png'.
""" """
CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree', METAINFO = dict(
'car', 'clutter') classes=('impervious_surface', 'building', 'low_vegetation', 'tree',
'car', 'clutter'),
palette=[[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
[255, 255, 0], [255, 0, 0]])
PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], def __init__(self, **kwargs) -> None:
[255, 255, 0], [255, 0, 0]] super().__init__(
def __init__(self, **kwargs):
super(ISPRSDataset, self).__init__(
img_suffix='.png', img_suffix='.png',
seg_map_suffix='.png', seg_map_suffix='.png',
reduce_zero_label=True, reduce_zero_label=True,

View File

@ -1,10 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import numpy as np
from PIL import Image
from mmseg.registry import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset
@ -17,76 +11,15 @@ class LoveDADataset(CustomDataset):
``reduce_zero_label`` should be set to True. The ``img_suffix`` and ``reduce_zero_label`` should be set to True. The ``img_suffix`` and
``seg_map_suffix`` are both fixed to '.png'. ``seg_map_suffix`` are both fixed to '.png'.
""" """
CLASSES = ('background', 'building', 'road', 'water', 'barren', 'forest', METAINFO = dict(
'agricultural') classes=('background', 'building', 'road', 'water', 'barren', 'forest',
'agricultural'),
palette=[[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
[159, 129, 183], [0, 255, 0], [255, 195, 128]])
PALETTE = [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], def __init__(self, **kwargs) -> None:
[159, 129, 183], [0, 255, 0], [255, 195, 128]] super().__init__(
def __init__(self, **kwargs):
super(LoveDADataset, self).__init__(
img_suffix='.png', img_suffix='.png',
seg_map_suffix='.png', seg_map_suffix='.png',
reduce_zero_label=True, reduce_zero_label=True,
**kwargs) **kwargs)
def results2img(self, results, imgfile_prefix, indices=None):
"""Write the segmentation results to images.
Args:
results (list[ndarray]): Testing results of the
dataset.
imgfile_prefix (str): The filename prefix of the png files.
If the prefix is "somepath/xxx",
the png files will be named "somepath/xxx.png".
indices (list[int], optional): Indices of input results, if not
set, all the indices of the dataset will be used.
Default: None.
Returns:
list[str: str]: result txt files which contains corresponding
semantic segmentation images.
"""
mmcv.mkdir_or_exist(imgfile_prefix)
result_files = []
for result, idx in zip(results, indices):
filename = self.img_infos[idx]['filename']
basename = osp.splitext(osp.basename(filename))[0]
png_filename = osp.join(imgfile_prefix, f'{basename}.png')
# The index range of official requirement is from 0 to 6.
output = Image.fromarray(result.astype(np.uint8))
output.save(png_filename)
result_files.append(png_filename)
return result_files
def format_results(self, results, imgfile_prefix, indices=None):
"""Format the results into dir (standard format for LoveDA evaluation).
Args:
results (list): Testing results of the dataset.
imgfile_prefix (str): The prefix of images files. It
includes the file path and the prefix of filename, e.g.,
"a/b/prefix".
indices (list[int], optional): Indices of input results,
if not set, all the indices of the dataset will be used.
Default: None.
Returns:
tuple: (result_files, tmp_dir), result_files is a list containing
the image paths, tmp_dir is the temporal directory created
for saving json/png files when img_prefix is not specified.
"""
if indices is None:
indices = list(range(len(self)))
assert isinstance(results, list), 'results must be a list.'
assert isinstance(indices, list), 'indices must be a list.'
result_files = self.results2img(results, imgfile_prefix, indices)
return result_files

View File

@ -7,7 +7,7 @@ from .cityscapes import CityscapesDataset
class NightDrivingDataset(CityscapesDataset): class NightDrivingDataset(CityscapesDataset):
"""NightDrivingDataset dataset.""" """NightDrivingDataset dataset."""
def __init__(self, **kwargs): def __init__(self, **kwargs) -> None:
super().__init__( super().__init__(
img_suffix='_leftImg8bit.png', img_suffix='_leftImg8bit.png',
seg_map_suffix='_gtCoarse_labelTrainIds.png', seg_map_suffix='_gtCoarse_labelTrainIds.png',

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from mmseg.registry import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset
@ -14,21 +15,21 @@ class PascalContextDataset(CustomDataset):
fixed to '.png'. fixed to '.png'.
Args: Args:
split (str): Split txt file for PascalContext. ann_file (str): Annotation file path.
""" """
CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', METAINFO = dict(
'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus', classes=('background', 'aeroplane', 'bag', 'bed', 'bedclothes',
'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle',
'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence', 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling',
'floor', 'flower', 'food', 'grass', 'ground', 'horse', 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog',
'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground',
'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep', 'horse', 'keyboard', 'light', 'motorbike', 'mountain',
'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road',
'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow',
'window', 'wood') 'sofa', 'table', 'track', 'train', 'tree', 'truck',
'tvmonitor', 'wall', 'water', 'window', 'wood'),
PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
@ -42,16 +43,17 @@ class PascalContextDataset(CustomDataset):
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]] [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])
def __init__(self, split, **kwargs): def __init__(self, ann_file: str, **kwargs) -> None:
super(PascalContextDataset, self).__init__( super().__init__(
img_suffix='.jpg', img_suffix='.jpg',
seg_map_suffix='.png', seg_map_suffix='.png',
split=split, ann_file=ann_file,
reduce_zero_label=False, reduce_zero_label=False,
**kwargs) **kwargs)
assert self.file_client.exists(self.img_dir) and self.split is not None assert self.file_client.exists(
self.data_prefix['img_path']) and osp.isfile(self.ann_file)
@DATASETS.register_module() @DATASETS.register_module()
@ -64,40 +66,41 @@ class PascalContextDataset59(CustomDataset):
fixed to '.png'. fixed to '.png'.
Args: Args:
split (str): Split txt file for PascalContext. ann_file (str): Annotation file path.
""" """
METAINFO = dict(
classes=('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
'bird', 'boat', 'book', 'bottle', 'building', 'bus',
'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
'floor', 'flower', 'food', 'grass', 'ground', 'horse',
'keyboard', 'light', 'motorbike', 'mountain', 'mouse',
'person', 'plate', 'platform', 'pottedplant', 'road', 'rock',
'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa',
'table', 'track', 'train', 'tree', 'truck', 'tvmonitor',
'wall', 'water', 'window', 'wood'),
palette=[[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
[120, 120, 80], [140, 140, 140], [204, 5, 255],
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])
CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', def __init__(self, ann_file: str, **kwargs):
'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', super().__init__(
'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')
PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
[120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
[4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
[120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
[204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
[61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
[255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
[112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
[10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
[102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
[0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
[235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
[250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
[255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
[0, 235, 255], [0, 173, 255], [31, 0, 255]]
def __init__(self, split, **kwargs):
super(PascalContextDataset59, self).__init__(
img_suffix='.jpg', img_suffix='.jpg',
seg_map_suffix='.png', seg_map_suffix='.png',
split=split, ann_file=ann_file,
reduce_zero_label=True, reduce_zero_label=True,
**kwargs) **kwargs)
assert self.file_client.exists(self.img_dir) and self.split is not None assert self.file_client.exists(
self.data_prefix['img_path']) and osp.isfile(self.ann_file)

View File

@ -11,14 +11,14 @@ class PotsdamDataset(CustomDataset):
``reduce_zero_label`` should be set to True. The ``img_suffix`` and ``reduce_zero_label`` should be set to True. The ``img_suffix`` and
``seg_map_suffix`` are both fixed to '.png'. ``seg_map_suffix`` are both fixed to '.png'.
""" """
CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree', METAINFO = dict(
'car', 'clutter') classes=('impervious_surface', 'building', 'low_vegetation', 'tree',
'car', 'clutter'),
palette=[[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
[255, 255, 0], [255, 0, 0]])
PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], def __init__(self, **kwargs) -> None:
[255, 255, 0], [255, 0, 0]] super().__init__(
def __init__(self, **kwargs):
super(PotsdamDataset, self).__init__(
img_suffix='.png', img_suffix='.png',
seg_map_suffix='.png', seg_map_suffix='.png',
reduce_zero_label=True, reduce_zero_label=True,

View File

@ -1,6 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from mmseg.registry import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset
@ -14,15 +12,14 @@ class STAREDataset(CustomDataset):
``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
'.ah.png'. '.ah.png'.
""" """
METAINFO = dict(
classes=('background', 'vessel'),
palette=[[120, 120, 120], [6, 230, 230]])
CLASSES = ('background', 'vessel') def __init__(self, **kwargs) -> None:
super().__init__(
PALETTE = [[120, 120, 120], [6, 230, 230]]
def __init__(self, **kwargs):
super(STAREDataset, self).__init__(
img_suffix='.png', img_suffix='.png',
seg_map_suffix='.ah.png', seg_map_suffix='.ah.png',
reduce_zero_label=False, reduce_zero_label=False,
**kwargs) **kwargs)
assert osp.exists(self.img_dir) assert self.file_client.exists(self.data_prefix['img_path'])

View File

@ -12,19 +12,23 @@ class PascalVOCDataset(CustomDataset):
Args: Args:
split (str): Split txt file for Pascal VOC. split (str): Split txt file for Pascal VOC.
""" """
METAINFO = dict(
classes=('background', 'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
'sofa', 'train', 'tvmonitor'),
palette=[[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
[0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
[64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
[64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
[0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
[0, 64, 128]])
CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', def __init__(self, ann_file, **kwargs) -> None:
'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', super().__init__(
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', img_suffix='.jpg',
'train', 'tvmonitor') seg_map_suffix='.png',
ann_file=ann_file,
PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], **kwargs)
[128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], assert self.file_client.exists(
[192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], self.data_prefix['img_path']) and osp.isfile(self.ann_file)
[192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
def __init__(self, split, **kwargs):
super(PascalVOCDataset, self).__init__(
img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs)
assert osp.exists(self.img_dir) and self.split is not None

View File

@ -0,0 +1,353 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import tempfile
from unittest.mock import MagicMock
import pytest
from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
COCOStuffDataset, CustomDataset, ISPRSDataset,
LoveDADataset, PascalVOCDataset, PotsdamDataset,
iSAIDDataset)
def test_classes():
assert list(
CityscapesDataset.METAINFO['classes']) == get_classes('cityscapes')
assert list(PascalVOCDataset.METAINFO['classes']) == get_classes(
'voc') == get_classes('pascal_voc')
assert list(ADE20KDataset.METAINFO['classes']) == get_classes(
'ade') == get_classes('ade20k')
assert list(
COCOStuffDataset.METAINFO['classes']) == get_classes('cocostuff')
assert list(LoveDADataset.METAINFO['classes']) == get_classes('loveda')
assert list(PotsdamDataset.METAINFO['classes']) == get_classes('potsdam')
assert list(ISPRSDataset.METAINFO['classes']) == get_classes('vaihingen')
assert list(iSAIDDataset.METAINFO['classes']) == get_classes('isaid')
with pytest.raises(ValueError):
get_classes('unsupported')
def test_classes_file_path():
tmp_file = tempfile.NamedTemporaryFile()
classes_path = f'{tmp_file.name}.txt'
train_pipeline = []
kwargs = dict(
pipeline=train_pipeline,
data_prefix=dict(img_path='./', seg_map_path='./'),
metainfo=dict(classes=classes_path))
# classes.txt with full categories
categories = get_classes('cityscapes')
with open(classes_path, 'w') as f:
f.write('\n'.join(categories))
dataset = CityscapesDataset(**kwargs)
assert list(dataset.metainfo['classes']) == categories
assert dataset.label_map is None
# classes.txt with sub categories
categories = ['road', 'sidewalk', 'building']
with open(classes_path, 'w') as f:
f.write('\n'.join(categories))
dataset = CityscapesDataset(**kwargs)
assert list(dataset.metainfo['classes']) == categories
assert dataset.label_map is not None
# classes.txt with unknown categories
categories = ['road', 'sidewalk', 'unknown']
with open(classes_path, 'w') as f:
f.write('\n'.join(categories))
with pytest.raises(ValueError):
CityscapesDataset(**kwargs)
tmp_file.close()
os.remove(classes_path)
assert not osp.exists(classes_path)
def test_palette():
assert CityscapesDataset.METAINFO['palette'] == get_palette('cityscapes')
assert PascalVOCDataset.METAINFO['palette'] == get_palette(
'voc') == get_palette('pascal_voc')
assert ADE20KDataset.METAINFO['palette'] == get_palette(
'ade') == get_palette('ade20k')
assert LoveDADataset.METAINFO['palette'] == get_palette('loveda')
assert PotsdamDataset.METAINFO['palette'] == get_palette('potsdam')
assert COCOStuffDataset.METAINFO['palette'] == get_palette('cocostuff')
assert iSAIDDataset.METAINFO['palette'] == get_palette('isaid')
with pytest.raises(ValueError):
get_palette('unsupported')
def test_custom_dataset():
# with 'img_path' and 'seg_map_path' in data_prefix
train_dataset = CustomDataset(
data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
data_prefix=dict(
img_path='imgs/',
seg_map_path='gts/',
),
img_suffix='img.jpg',
seg_map_suffix='gt.png')
assert len(train_dataset) == 5
# with 'img_path' and 'seg_map_path' in data_prefix and ann_file
train_dataset = CustomDataset(
data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
data_prefix=dict(
img_path='imgs/',
seg_map_path='gts/',
),
img_suffix='img.jpg',
seg_map_suffix='gt.png',
ann_file='splits/train.txt')
assert len(train_dataset) == 4
# no data_root
train_dataset = CustomDataset(
data_prefix=dict(
img_path=osp.join(
osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
seg_map_path=osp.join(
osp.dirname(__file__), '../data/pseudo_dataset/gts')),
img_suffix='img.jpg',
seg_map_suffix='gt.png')
assert len(train_dataset) == 5
# with data_root but 'img_path' and 'seg_map_path' in data_prefix are
# abs path
train_dataset = CustomDataset(
data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
data_prefix=dict(
img_path=osp.join(
osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
seg_map_path=osp.join(
osp.dirname(__file__), '../data/pseudo_dataset/gts')),
img_suffix='img.jpg',
seg_map_suffix='gt.png')
assert len(train_dataset) == 5
# test_mode=True
test_dataset = CustomDataset(
data_prefix=dict(
img_path=osp.join(
osp.dirname(__file__), '../data/pseudo_dataset/imgs')),
img_suffix='img.jpg',
test_mode=True,
metainfo=dict(classes=('pseudo_class', )))
assert len(test_dataset) == 5
# training data get
train_data = train_dataset[0]
assert isinstance(train_data, dict)
assert 'img_path' in train_data and osp.isfile(train_data['img_path'])
assert 'seg_map_path' in train_data and osp.isfile(
train_data['seg_map_path'])
# test data get
test_data = test_dataset[0]
assert isinstance(test_data, dict)
assert 'img_path' in train_data and osp.isfile(train_data['img_path'])
assert 'seg_map_path' in train_data and osp.isfile(
train_data['seg_map_path'])
def test_ade():
test_dataset = ADE20KDataset(
pipeline=[],
data_prefix=dict(
img_path=osp.join(
osp.dirname(__file__), '../data/pseudo_dataset/imgs')))
assert len(test_dataset) == 5
def test_cityscapes():
test_dataset = CityscapesDataset(
pipeline=[],
data_prefix=dict(
img_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_cityscapes_dataset/leftImg8bit'),
seg_map_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_cityscapes_dataset/gtFine')))
assert len(test_dataset) == 1
def test_loveda():
test_dataset = LoveDADataset(
pipeline=[],
data_prefix=dict(
img_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_loveda_dataset/img_dir'),
seg_map_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_loveda_dataset/ann_dir')))
assert len(test_dataset) == 3
def test_potsdam():
test_dataset = PotsdamDataset(
pipeline=[],
data_prefix=dict(
img_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_potsdam_dataset/img_dir'),
seg_map_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_potsdam_dataset/ann_dir')))
assert len(test_dataset) == 1
def test_vaihingen():
test_dataset = ISPRSDataset(
pipeline=[],
data_prefix=dict(
img_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_vaihingen_dataset/img_dir'),
seg_map_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_vaihingen_dataset/ann_dir')))
assert len(test_dataset) == 1
def test_isaid():
test_dataset = iSAIDDataset(
pipeline=[],
data_prefix=dict(
img_path=osp.join(
osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
seg_map_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_isaid_dataset/ann_dir')))
assert len(test_dataset) == 2
test_dataset = iSAIDDataset(
data_prefix=dict(
img_path=osp.join(
osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
seg_map_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_isaid_dataset/ann_dir')),
ann_file=osp.join(
osp.dirname(__file__),
'../data/pseudo_isaid_dataset/splits/train.txt'))
assert len(test_dataset) == 1
@pytest.mark.parametrize('dataset, classes', [
('ADE20KDataset', ('wall', 'building')),
('CityscapesDataset', ('road', 'sidewalk')),
('CustomDataset', ('bus', 'car')),
('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):
dataset_class = DATASETS.get(dataset)
if isinstance(dataset_class, PascalVOCDataset):
tmp_file = tempfile.NamedTemporaryFile()
ann_file = f'{tmp_file.name}.txt'
else:
ann_file = MagicMock()
original_classes = dataset_class.METAINFO.get('classes', None)
# Test setting classes as a tuple
custom_dataset = dataset_class(
data_prefix=dict(img_path=MagicMock()),
ann_file=ann_file,
metainfo=dict(classes=classes),
test_mode=True,
lazy_init=True)
assert custom_dataset.metainfo['classes'] != original_classes
assert custom_dataset.metainfo['classes'] == classes
if not isinstance(custom_dataset, CustomDataset):
assert isinstance(custom_dataset.label_map, dict)
# Test setting classes as a list
custom_dataset = dataset_class(
data_prefix=dict(img_path=MagicMock()),
ann_file=ann_file,
metainfo=dict(classes=list(classes)),
test_mode=True,
lazy_init=True)
assert custom_dataset.metainfo['classes'] != original_classes
assert custom_dataset.metainfo['classes'] == list(classes)
if not isinstance(custom_dataset, CustomDataset):
assert isinstance(custom_dataset.label_map, dict)
# Test overriding not a subset
custom_dataset = dataset_class(
ann_file=ann_file,
data_prefix=dict(img_path=MagicMock()),
metainfo=dict(classes=[classes[0]]),
test_mode=True,
lazy_init=True)
assert custom_dataset.metainfo['classes'] != original_classes
assert custom_dataset.metainfo['classes'] == [classes[0]]
if not isinstance(custom_dataset, CustomDataset):
assert isinstance(custom_dataset.label_map, dict)
# Test default behavior
if dataset_class is CustomDataset:
with pytest.raises(AssertionError):
custom_dataset = dataset_class(
ann_file=ann_file,
data_prefix=dict(img_path=MagicMock()),
metainfo=None,
test_mode=True,
lazy_init=True)
else:
custom_dataset = dataset_class(
data_prefix=dict(img_path=MagicMock()),
ann_file=ann_file,
metainfo=None,
test_mode=True,
lazy_init=True)
assert custom_dataset.METAINFO['classes'] == original_classes
assert custom_dataset.label_map is None
def test_custom_dataset_random_palette_is_generated():
dataset = CustomDataset(
pipeline=[],
data_prefix=dict(img_path=MagicMock()),
ann_file=MagicMock(),
metainfo=dict(classes=('bus', 'car')),
lazy_init=True,
test_mode=True)
assert len(dataset.metainfo['palette']) == 2
for class_color in dataset.metainfo['palette']:
assert len(class_color) == 3
assert all(x >= 0 and x <= 255 for x in class_color)
def test_custom_dataset_custom_palette():
dataset = CustomDataset(
data_prefix=dict(img_path=MagicMock()),
ann_file=MagicMock(),
metainfo=dict(
classes=('bus', 'car'), palette=[[100, 100, 100], [200, 200,
200]]),
lazy_init=True,
test_mode=True)
assert tuple(dataset.metainfo['palette']) == tuple([[100, 100, 100],
[200, 200, 200]])
# test custom class and palette don't match
with pytest.raises(ValueError):
dataset = CustomDataset(
data_prefix=dict(img_path=MagicMock()),
ann_file=MagicMock(),
metainfo=dict(classes=('bus', 'car'), palette=[[200, 200, 200]]),
lazy_init=True)