Merge branch 'zhengmiao/refactory-dataset' into 'refactor_dev'
[Refactory] Dataset refactory See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!15pull/1801/head
commit
d64f941fb3
|
@ -1,10 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import CustomDataset
|
||||
|
||||
|
@ -18,150 +12,77 @@ class ADE20KDataset(CustomDataset):
|
|||
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
|
||||
'.png'.
|
||||
"""
|
||||
CLASSES = (
|
||||
'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
|
||||
'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
|
||||
'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
|
||||
'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
|
||||
'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
|
||||
'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
|
||||
'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
|
||||
'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
|
||||
'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
|
||||
'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
|
||||
'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
|
||||
'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
|
||||
'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
|
||||
'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
|
||||
'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
|
||||
'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
|
||||
'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
|
||||
'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
|
||||
'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
|
||||
'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
|
||||
'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
|
||||
'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
|
||||
'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
|
||||
'clock', 'flag')
|
||||
METAINFO = dict(
|
||||
classes=('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road',
|
||||
'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk',
|
||||
'person', 'earth', 'door', 'table', 'mountain', 'plant',
|
||||
'curtain', 'chair', 'car', 'water', 'painting', 'sofa',
|
||||
'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair',
|
||||
'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp',
|
||||
'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
|
||||
'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
|
||||
'skyscraper', 'fireplace', 'refrigerator', 'grandstand',
|
||||
'path', 'stairs', 'runway', 'case', 'pool table', 'pillow',
|
||||
'screen door', 'stairway', 'river', 'bridge', 'bookcase',
|
||||
'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill',
|
||||
'bench', 'countertop', 'stove', 'palm', 'kitchen island',
|
||||
'computer', 'swivel chair', 'boat', 'bar', 'arcade machine',
|
||||
'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
|
||||
'chandelier', 'awning', 'streetlight', 'booth',
|
||||
'television receiver', 'airplane', 'dirt track', 'apparel',
|
||||
'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle',
|
||||
'buffet', 'poster', 'stage', 'van', 'ship', 'fountain',
|
||||
'conveyer belt', 'canopy', 'washer', 'plaything',
|
||||
'swimming pool', 'stool', 'barrel', 'basket', 'waterfall',
|
||||
'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food',
|
||||
'step', 'tank', 'trade name', 'microwave', 'pot', 'animal',
|
||||
'bicycle', 'lake', 'dishwasher', 'screen', 'blanket',
|
||||
'sculpture', 'hood', 'sconce', 'vase', 'traffic light',
|
||||
'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate',
|
||||
'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
|
||||
'clock', 'flag'),
|
||||
palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
|
||||
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
|
||||
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
||||
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
||||
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
||||
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
||||
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
||||
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
||||
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
||||
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
||||
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
||||
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
||||
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
||||
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
||||
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
|
||||
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
|
||||
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
|
||||
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
|
||||
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
|
||||
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
|
||||
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
|
||||
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
|
||||
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
|
||||
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
|
||||
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
|
||||
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
|
||||
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
|
||||
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
|
||||
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
|
||||
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
|
||||
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
|
||||
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
|
||||
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
|
||||
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
|
||||
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
|
||||
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
|
||||
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
|
||||
[102, 255, 0], [92, 0, 255]])
|
||||
|
||||
PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
|
||||
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
|
||||
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
||||
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
||||
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
||||
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
||||
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
||||
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
||||
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
||||
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
||||
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
||||
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
||||
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
||||
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
||||
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
|
||||
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
|
||||
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
|
||||
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
|
||||
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
|
||||
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
|
||||
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
|
||||
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
|
||||
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
|
||||
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
|
||||
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
|
||||
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
|
||||
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
|
||||
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
|
||||
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
|
||||
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
|
||||
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
|
||||
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
|
||||
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
|
||||
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
|
||||
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
|
||||
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
|
||||
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
|
||||
[102, 255, 0], [92, 0, 255]]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(ADE20KDataset, self).__init__(
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='.jpg',
|
||||
seg_map_suffix='.png',
|
||||
reduce_zero_label=True,
|
||||
**kwargs)
|
||||
|
||||
def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
|
||||
"""Write the segmentation results to images.
|
||||
|
||||
Args:
|
||||
results (list[ndarray]): Testing results of the
|
||||
dataset.
|
||||
imgfile_prefix (str): The filename prefix of the png files.
|
||||
If the prefix is "somepath/xxx",
|
||||
the png files will be named "somepath/xxx.png".
|
||||
to_label_id (bool): whether convert output to label_id for
|
||||
submission.
|
||||
indices (list[int], optional): Indices of input results, if not
|
||||
set, all the indices of the dataset will be used.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
list[str: str]: result txt files which contains corresponding
|
||||
semantic segmentation images.
|
||||
"""
|
||||
if indices is None:
|
||||
indices = list(range(len(self)))
|
||||
|
||||
mmcv.mkdir_or_exist(imgfile_prefix)
|
||||
result_files = []
|
||||
for result, idx in zip(results, indices):
|
||||
|
||||
filename = self.img_infos[idx]['filename']
|
||||
basename = osp.splitext(osp.basename(filename))[0]
|
||||
|
||||
png_filename = osp.join(imgfile_prefix, f'{basename}.png')
|
||||
|
||||
# The index range of official requirement is from 0 to 150.
|
||||
# But the index range of output is from 0 to 149.
|
||||
# That is because we set reduce_zero_label=True.
|
||||
result = result + 1
|
||||
|
||||
output = Image.fromarray(result.astype(np.uint8))
|
||||
output.save(png_filename)
|
||||
result_files.append(png_filename)
|
||||
|
||||
return result_files
|
||||
|
||||
def format_results(self,
|
||||
results,
|
||||
imgfile_prefix,
|
||||
to_label_id=True,
|
||||
indices=None):
|
||||
"""Format the results into dir (standard format for ade20k evaluation).
|
||||
|
||||
Args:
|
||||
results (list): Testing results of the dataset.
|
||||
imgfile_prefix (str | None): The prefix of images files. It
|
||||
includes the file path and the prefix of filename, e.g.,
|
||||
"a/b/prefix".
|
||||
to_label_id (bool): whether convert output to label_id for
|
||||
submission. Default: False
|
||||
indices (list[int], optional): Indices of input results, if not
|
||||
set, all the indices of the dataset will be used.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
tuple: (result_files, tmp_dir), result_files is a list containing
|
||||
the image paths, tmp_dir is the temporal directory created
|
||||
for saving json/png files when img_prefix is not specified.
|
||||
"""
|
||||
|
||||
if indices is None:
|
||||
indices = list(range(len(self)))
|
||||
|
||||
assert isinstance(results, list), 'results must be a list.'
|
||||
assert isinstance(indices, list), 'indices must be a list.'
|
||||
|
||||
result_files = self.results2img(results, imgfile_prefix, to_label_id,
|
||||
indices)
|
||||
return result_files
|
||||
|
|
|
@ -13,15 +13,14 @@ class ChaseDB1Dataset(CustomDataset):
|
|||
The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
|
||||
'_1stHO.png'.
|
||||
"""
|
||||
METAFILE = dict(
|
||||
classes=('background', 'vessel'),
|
||||
palette=[[120, 120, 120], [6, 230, 230]])
|
||||
|
||||
CLASSES = ('background', 'vessel')
|
||||
|
||||
PALETTE = [[120, 120, 120], [6, 230, 230]]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(ChaseDB1Dataset, self).__init__(
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='.png',
|
||||
seg_map_suffix='_1stHO.png',
|
||||
reduce_zero_label=False,
|
||||
**kwargs)
|
||||
assert self.file_client.exists(self.img_dir)
|
||||
assert self.file_client.exists(self.data_prefix['img_list'])
|
||||
|
|
|
@ -1,11 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.utils import print_log
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import CustomDataset
|
||||
|
||||
|
@ -17,198 +10,21 @@ class CityscapesDataset(CustomDataset):
|
|||
The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is
|
||||
fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
|
||||
"""
|
||||
|
||||
CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
|
||||
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
|
||||
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
|
||||
'bicycle')
|
||||
|
||||
PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
|
||||
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
|
||||
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
|
||||
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
|
||||
[0, 80, 100], [0, 0, 230], [119, 11, 32]]
|
||||
METAINFO = dict(
|
||||
classes=('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
|
||||
'traffic light', 'traffic sign', 'vegetation', 'terrain',
|
||||
'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
|
||||
'motorcycle', 'bicycle'),
|
||||
palette=[[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
|
||||
[190, 153, 153], [153, 153, 153], [250, 170,
|
||||
30], [220, 220, 0],
|
||||
[107, 142, 35], [152, 251, 152], [70, 130, 180],
|
||||
[220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
|
||||
[0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]])
|
||||
|
||||
def __init__(self,
|
||||
img_suffix='_leftImg8bit.png',
|
||||
seg_map_suffix='_gtFine_labelTrainIds.png',
|
||||
**kwargs):
|
||||
super(CityscapesDataset, self).__init__(
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_label_id(result):
|
||||
"""Convert trainId to id for cityscapes."""
|
||||
if isinstance(result, str):
|
||||
result = np.load(result)
|
||||
import cityscapesscripts.helpers.labels as CSLabels
|
||||
result_copy = result.copy()
|
||||
for trainId, label in CSLabels.trainId2label.items():
|
||||
result_copy[result == trainId] = label.id
|
||||
|
||||
return result_copy
|
||||
|
||||
def results2img(self, results, imgfile_prefix, to_label_id, indices=None):
|
||||
"""Write the segmentation results to images.
|
||||
|
||||
Args:
|
||||
results (list[ndarray]): Testing results of the
|
||||
dataset.
|
||||
imgfile_prefix (str): The filename prefix of the png files.
|
||||
If the prefix is "somepath/xxx",
|
||||
the png files will be named "somepath/xxx.png".
|
||||
to_label_id (bool): whether convert output to label_id for
|
||||
submission.
|
||||
indices (list[int], optional): Indices of input results,
|
||||
if not set, all the indices of the dataset will be used.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
list[str: str]: result txt files which contains corresponding
|
||||
semantic segmentation images.
|
||||
"""
|
||||
if indices is None:
|
||||
indices = list(range(len(self)))
|
||||
|
||||
mmcv.mkdir_or_exist(imgfile_prefix)
|
||||
result_files = []
|
||||
for result, idx in zip(results, indices):
|
||||
if to_label_id:
|
||||
result = self._convert_to_label_id(result)
|
||||
filename = self.img_infos[idx]['filename']
|
||||
basename = osp.splitext(osp.basename(filename))[0]
|
||||
|
||||
png_filename = osp.join(imgfile_prefix, f'{basename}.png')
|
||||
|
||||
output = Image.fromarray(result.astype(np.uint8)).convert('P')
|
||||
import cityscapesscripts.helpers.labels as CSLabels
|
||||
palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
|
||||
for label_id, label in CSLabels.id2label.items():
|
||||
palette[label_id] = label.color
|
||||
|
||||
output.putpalette(palette)
|
||||
output.save(png_filename)
|
||||
result_files.append(png_filename)
|
||||
|
||||
return result_files
|
||||
|
||||
def format_results(self,
|
||||
results,
|
||||
imgfile_prefix,
|
||||
to_label_id=True,
|
||||
indices=None):
|
||||
"""Format the results into dir (standard format for Cityscapes
|
||||
evaluation).
|
||||
|
||||
Args:
|
||||
results (list): Testing results of the dataset.
|
||||
imgfile_prefix (str): The prefix of images files. It
|
||||
includes the file path and the prefix of filename, e.g.,
|
||||
"a/b/prefix".
|
||||
to_label_id (bool): whether convert output to label_id for
|
||||
submission. Default: False
|
||||
indices (list[int], optional): Indices of input results,
|
||||
if not set, all the indices of the dataset will be used.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
tuple: (result_files, tmp_dir), result_files is a list containing
|
||||
the image paths, tmp_dir is the temporal directory created
|
||||
for saving json/png files when img_prefix is not specified.
|
||||
"""
|
||||
if indices is None:
|
||||
indices = list(range(len(self)))
|
||||
|
||||
assert isinstance(results, list), 'results must be a list.'
|
||||
assert isinstance(indices, list), 'indices must be a list.'
|
||||
|
||||
result_files = self.results2img(results, imgfile_prefix, to_label_id,
|
||||
indices)
|
||||
|
||||
return result_files
|
||||
|
||||
def evaluate(self,
|
||||
results,
|
||||
metric='mIoU',
|
||||
logger=None,
|
||||
imgfile_prefix=None):
|
||||
"""Evaluation in Cityscapes/default protocol.
|
||||
|
||||
Args:
|
||||
results (list): Testing results of the dataset.
|
||||
metric (str | list[str]): Metrics to be evaluated.
|
||||
logger (logging.Logger | None | str): Logger used for printing
|
||||
related information during evaluation. Default: None.
|
||||
imgfile_prefix (str | None): The prefix of output image file,
|
||||
for cityscapes evaluation only. It includes the file path and
|
||||
the prefix of filename, e.g., "a/b/prefix".
|
||||
If results are evaluated with cityscapes protocol, it would be
|
||||
the prefix of output png files. The output files would be
|
||||
png images under folder "a/b/prefix/xxx.png", where "xxx" is
|
||||
the image name of cityscapes. If not specified, a temp file
|
||||
will be created for evaluation.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: Cityscapes/default metrics.
|
||||
"""
|
||||
|
||||
eval_results = dict()
|
||||
metrics = metric.copy() if isinstance(metric, list) else [metric]
|
||||
if 'cityscapes' in metrics:
|
||||
eval_results.update(
|
||||
self._evaluate_cityscapes(results, logger, imgfile_prefix))
|
||||
metrics.remove('cityscapes')
|
||||
if len(metrics) > 0:
|
||||
eval_results.update(
|
||||
super(CityscapesDataset,
|
||||
self).evaluate(results, metrics, logger))
|
||||
|
||||
return eval_results
|
||||
|
||||
def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
|
||||
"""Evaluation in Cityscapes protocol.
|
||||
|
||||
Args:
|
||||
results (list): Testing results of the dataset.
|
||||
logger (logging.Logger | str | None): Logger used for printing
|
||||
related information during evaluation. Default: None.
|
||||
imgfile_prefix (str | None): The prefix of output image file
|
||||
|
||||
Returns:
|
||||
dict[str: float]: Cityscapes evaluation results.
|
||||
"""
|
||||
try:
|
||||
import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
|
||||
except ImportError:
|
||||
raise ImportError('Please run "pip install cityscapesscripts" to '
|
||||
'install cityscapesscripts first.')
|
||||
msg = 'Evaluating in Cityscapes style'
|
||||
if logger is None:
|
||||
msg = '\n' + msg
|
||||
print_log(msg, logger=logger)
|
||||
|
||||
result_dir = imgfile_prefix
|
||||
|
||||
eval_results = dict()
|
||||
print_log(f'Evaluating results under {result_dir} ...', logger=logger)
|
||||
|
||||
CSEval.args.evalInstLevelScore = True
|
||||
CSEval.args.predictionPath = osp.abspath(result_dir)
|
||||
CSEval.args.evalPixelAccuracy = True
|
||||
CSEval.args.JSONOutput = False
|
||||
|
||||
seg_map_list = []
|
||||
pred_list = []
|
||||
|
||||
# when evaluating with official cityscapesscripts,
|
||||
# **_gtFine_labelIds.png is used
|
||||
for seg_map in mmcv.scandir(
|
||||
self.ann_dir, 'gtFine_labelIds.png', recursive=True):
|
||||
seg_map_list.append(osp.join(self.ann_dir, seg_map))
|
||||
pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
|
||||
|
||||
eval_results.update(
|
||||
CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
|
||||
|
||||
return eval_results
|
||||
|
|
|
@ -14,81 +14,83 @@ class COCOStuffDataset(CustomDataset):
|
|||
10k and 164k versions, respectively. The ``img_suffix`` is fixed to '.jpg',
|
||||
and ``seg_map_suffix`` is fixed to '.png'.
|
||||
"""
|
||||
CLASSES = (
|
||||
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
|
||||
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
|
||||
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
|
||||
'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
|
||||
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
|
||||
'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
||||
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
|
||||
'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
|
||||
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
||||
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
|
||||
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
|
||||
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
|
||||
'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
|
||||
'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
|
||||
'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
|
||||
'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
|
||||
'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
|
||||
'floor-other', 'floor-stone', 'floor-tile', 'floor-wood',
|
||||
'flower', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass',
|
||||
'gravel', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat',
|
||||
'metal', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net',
|
||||
'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 'platform',
|
||||
'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',
|
||||
'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',
|
||||
'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',
|
||||
'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',
|
||||
'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',
|
||||
'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
|
||||
'window-blind', 'window-other', 'wood')
|
||||
METAINFO = dict(
|
||||
classes=(
|
||||
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
||||
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
||||
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
|
||||
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
|
||||
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
|
||||
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
|
||||
'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
|
||||
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
|
||||
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
|
||||
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
|
||||
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
|
||||
'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
|
||||
'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
|
||||
'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
|
||||
'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',
|
||||
'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',
|
||||
'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',
|
||||
'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',
|
||||
'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower',
|
||||
'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel',
|
||||
'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal',
|
||||
'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net',
|
||||
'paper', 'pavement', 'pillow', 'plant-other', 'plastic',
|
||||
'platform', 'playingfield', 'railing', 'railroad', 'river', 'road',
|
||||
'rock', 'roof', 'rug', 'salad', 'sand', 'sea', 'shelf',
|
||||
'sky-other', 'skyscraper', 'snow', 'solid-other', 'stairs',
|
||||
'stone', 'straw', 'structural-other', 'table', 'tent',
|
||||
'textile-other', 'towel', 'tree', 'vegetable', 'wall-brick',
|
||||
'wall-concrete', 'wall-other', 'wall-panel', 'wall-stone',
|
||||
'wall-tile', 'wall-wood', 'water-other', 'waterdrops',
|
||||
'window-blind', 'window-other', 'wood'),
|
||||
palette=[[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
|
||||
[0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
|
||||
[0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
|
||||
[0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
|
||||
[0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
|
||||
[128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
|
||||
[64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160],
|
||||
[0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0],
|
||||
[0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128],
|
||||
[64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160],
|
||||
[0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],
|
||||
[128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192],
|
||||
[0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160],
|
||||
[64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0],
|
||||
[0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
|
||||
[0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],
|
||||
[64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128],
|
||||
[128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128],
|
||||
[64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224],
|
||||
[64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0],
|
||||
[0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128],
|
||||
[64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224],
|
||||
[64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128],
|
||||
[128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192],
|
||||
[0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224],
|
||||
[0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0],
|
||||
[64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192],
|
||||
[0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224],
|
||||
[0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128],
|
||||
[192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128],
|
||||
[64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160],
|
||||
[0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64],
|
||||
[64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128],
|
||||
[64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160],
|
||||
[0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192],
|
||||
[192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192],
|
||||
[0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160],
|
||||
[64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64],
|
||||
[64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192],
|
||||
[0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],
|
||||
[64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],
|
||||
[192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],
|
||||
[64, 192, 96], [64, 160, 64], [64, 64, 0]])
|
||||
|
||||
PALETTE = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
|
||||
[0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
|
||||
[0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
|
||||
[0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
|
||||
[0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
|
||||
[128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
|
||||
[64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160],
|
||||
[0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0],
|
||||
[0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128],
|
||||
[64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160],
|
||||
[0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],
|
||||
[128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192],
|
||||
[0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160],
|
||||
[64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0],
|
||||
[0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
|
||||
[0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],
|
||||
[64, 32, 128], [128, 192, 192], [0, 0, 160], [192, 160, 128],
|
||||
[128, 192, 0], [128, 0, 96], [192, 32, 0], [128, 64, 128],
|
||||
[64, 128, 96], [64, 160, 0], [0, 64, 0], [192, 128, 224],
|
||||
[64, 32, 0], [0, 192, 128], [64, 128, 224], [192, 160, 0],
|
||||
[0, 192, 0], [192, 128, 96], [192, 96, 128], [0, 64, 128],
|
||||
[64, 0, 96], [64, 224, 128], [128, 64, 0], [192, 0, 224],
|
||||
[64, 96, 128], [128, 192, 128], [64, 0, 224], [192, 224, 128],
|
||||
[128, 192, 64], [192, 0, 96], [192, 96, 0], [128, 64, 192],
|
||||
[0, 128, 96], [0, 224, 0], [64, 64, 64], [128, 128, 224],
|
||||
[0, 96, 0], [64, 192, 192], [0, 128, 224], [128, 224, 0],
|
||||
[64, 192, 64], [128, 128, 96], [128, 32, 128], [64, 0, 192],
|
||||
[0, 64, 96], [0, 160, 128], [192, 0, 64], [128, 64, 224],
|
||||
[0, 32, 128], [192, 128, 192], [0, 64, 224], [128, 160, 128],
|
||||
[192, 128, 0], [128, 64, 32], [128, 32, 64], [192, 0, 128],
|
||||
[64, 192, 32], [0, 160, 64], [64, 0, 0], [192, 192, 160],
|
||||
[0, 32, 64], [64, 128, 128], [64, 192, 160], [128, 160, 64],
|
||||
[64, 128, 0], [192, 192, 32], [128, 96, 192], [64, 0, 128],
|
||||
[64, 64, 32], [0, 224, 192], [192, 0, 0], [192, 64, 160],
|
||||
[0, 96, 192], [192, 128, 128], [64, 64, 160], [128, 224, 192],
|
||||
[192, 128, 64], [192, 64, 32], [128, 96, 64], [192, 0, 192],
|
||||
[0, 192, 32], [64, 224, 64], [64, 0, 64], [128, 192, 160],
|
||||
[64, 96, 64], [64, 128, 192], [0, 192, 160], [192, 224, 64],
|
||||
[64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192],
|
||||
[0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],
|
||||
[64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],
|
||||
[192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],
|
||||
[64, 192, 96], [64, 160, 64], [64, 64, 0]]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(COCOStuffDataset, self).__init__(
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='.jpg', seg_map_suffix='_labelTrainIds.png', **kwargs)
|
||||
|
|
|
@ -1,22 +1,17 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.utils import print_log
|
||||
from prettytable import PrettyTable
|
||||
from torch.utils.data import Dataset
|
||||
from mmengine.dataset import BaseDataset, Compose
|
||||
|
||||
from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
|
||||
from mmseg.registry import DATASETS
|
||||
from mmseg.utils import get_root_logger
|
||||
from .pipelines import Compose, LoadAnnotations
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class CustomDataset(Dataset):
|
||||
class CustomDataset(BaseDataset):
|
||||
"""Custom dataset for semantic segmentation. An example of file structure
|
||||
is as followed.
|
||||
|
||||
|
@ -46,442 +41,220 @@ class CustomDataset(Dataset):
|
|||
|
||||
|
||||
Args:
|
||||
pipeline (list[dict]): Processing pipeline
|
||||
img_dir (str): Path to image directory
|
||||
ann_file (str): Annotation file path. Defaults to ''.
|
||||
metainfo (dict, optional): Meta information for dataset, such as
|
||||
specify classes to load. Defaults to None.
|
||||
data_root (str, optional): The root directory for ``data_prefix`` and
|
||||
``ann_file``. Defaults to None.
|
||||
data_prefix (dict, optional): Prefix for training data. Defaults to
|
||||
dict(img_path=None, seg_path=None).
|
||||
img_suffix (str): Suffix of images. Default: '.jpg'
|
||||
ann_dir (str, optional): Path to annotation directory. Default: None
|
||||
seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
|
||||
split (str, optional): Split txt file. If split is specified, only
|
||||
file with suffix in the splits will be loaded. Otherwise, all
|
||||
images in img_dir/ann_dir will be loaded. Default: None
|
||||
data_root (str, optional): Data root for img_dir/ann_dir. Default:
|
||||
None.
|
||||
test_mode (bool): If test_mode=True, gt wouldn't be loaded.
|
||||
filter_cfg (dict, optional): Config for filter data. Defaults to None.
|
||||
indices (int or Sequence[int], optional): Support using first few
|
||||
data in annotation file to facilitate training/testing on a smaller
|
||||
dataset. Defaults to None which means using all ``data_infos``.
|
||||
serialize_data (bool, optional): Whether to hold memory using
|
||||
serialized objects, when enabled, data loader workers can use
|
||||
shared RAM from master process instead of making a copy. Defaults
|
||||
to True.
|
||||
pipeline (list, optional): Processing pipeline. Defaults to [].
|
||||
test_mode (bool, optional): ``test_mode=True`` means in test phase.
|
||||
Defaults to False.
|
||||
lazy_init (bool, optional): Whether to load annotation during
|
||||
instantiation. In some cases, such as visualization, only the meta
|
||||
information of the dataset is needed, which is not necessary to
|
||||
load annotation file. ``Basedataset`` can skip load annotations to
|
||||
save time by set ``lazy_init=False``. Defaults to False.
|
||||
max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
|
||||
None img. The maximum extra number of cycles to get a valid
|
||||
image. Defaults to 1000.
|
||||
ignore_index (int): The label index to be ignored. Default: 255
|
||||
reduce_zero_label (bool): Whether to mark label zero as ignored.
|
||||
Default: False
|
||||
classes (str | Sequence[str], optional): Specify classes to load.
|
||||
If is None, ``cls.CLASSES`` will be used. Default: None.
|
||||
palette (Sequence[Sequence[int]]] | np.ndarray | None):
|
||||
The palette of segmentation map. If None is given, and
|
||||
self.PALETTE is None, random palette will be generated.
|
||||
Default: None
|
||||
gt_seg_map_loader_cfg (dict, optional): build LoadAnnotations to
|
||||
load gt for evaluation, load from disk by default. Default: None.
|
||||
file_client_args (dict): Arguments to instantiate a FileClient.
|
||||
See :class:`mmcv.fileio.FileClient` for details.
|
||||
Defaults to ``dict(backend='disk')``.
|
||||
"""
|
||||
METAINFO: dict = dict()
|
||||
|
||||
CLASSES = None
|
||||
def __init__(
|
||||
self,
|
||||
ann_file: str = '',
|
||||
img_suffix='.jpg',
|
||||
seg_map_suffix='.png',
|
||||
metainfo: Optional[dict] = None,
|
||||
data_root: Optional[str] = None,
|
||||
data_prefix: dict = dict(img_path=None, seg_map_path=None),
|
||||
filter_cfg: Optional[dict] = None,
|
||||
indices: Optional[Union[int, Sequence[int]]] = None,
|
||||
serialize_data: bool = True,
|
||||
pipeline: List[Union[dict, Callable]] = [],
|
||||
test_mode: bool = False,
|
||||
lazy_init: bool = False,
|
||||
max_refetch: int = 1000,
|
||||
ignore_index: int = 255,
|
||||
reduce_zero_label: bool = True,
|
||||
file_client_args: dict = dict(backend='disk')
|
||||
) -> None:
|
||||
|
||||
PALETTE = None
|
||||
|
||||
def __init__(self,
|
||||
pipeline,
|
||||
img_dir,
|
||||
img_suffix='.jpg',
|
||||
ann_dir=None,
|
||||
seg_map_suffix='.png',
|
||||
split=None,
|
||||
data_root=None,
|
||||
test_mode=False,
|
||||
ignore_index=255,
|
||||
reduce_zero_label=False,
|
||||
classes=None,
|
||||
palette=None,
|
||||
gt_seg_map_loader_cfg=None,
|
||||
file_client_args=dict(backend='disk')):
|
||||
self.pipeline = Compose(pipeline)
|
||||
self.img_dir = img_dir
|
||||
self.img_suffix = img_suffix
|
||||
self.ann_dir = ann_dir
|
||||
self.seg_map_suffix = seg_map_suffix
|
||||
self.split = split
|
||||
self.data_root = data_root
|
||||
self.test_mode = test_mode
|
||||
self.ignore_index = ignore_index
|
||||
self.reduce_zero_label = reduce_zero_label
|
||||
self.label_map = None
|
||||
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
|
||||
classes, palette)
|
||||
self.gt_seg_map_loader = LoadAnnotations(
|
||||
) if gt_seg_map_loader_cfg is None else LoadAnnotations(
|
||||
**gt_seg_map_loader_cfg)
|
||||
|
||||
self.file_client_args = file_client_args
|
||||
self.file_client = mmcv.FileClient.infer_client(self.file_client_args)
|
||||
|
||||
self.data_root = data_root
|
||||
self.data_prefix = copy.copy(data_prefix)
|
||||
self.ann_file = ann_file
|
||||
self.filter_cfg = copy.deepcopy(filter_cfg)
|
||||
self._indices = indices
|
||||
self.serialize_data = serialize_data
|
||||
self.test_mode = test_mode
|
||||
self.max_refetch = max_refetch
|
||||
self.data_list: List[dict] = []
|
||||
self.data_bytes: np.ndarray
|
||||
|
||||
# Set meta information.
|
||||
self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))
|
||||
|
||||
# Get label map for custom classes
|
||||
new_classes = self._metainfo.get('classes', None)
|
||||
self.label_map = self.get_label_map(new_classes)
|
||||
|
||||
# Update palette based on label map or generate palette
|
||||
# if it is not defined
|
||||
updated_palette = self._update_palette()
|
||||
self._metainfo.update(dict(palette=updated_palette))
|
||||
if test_mode:
|
||||
assert self.CLASSES is not None, \
|
||||
'`cls.CLASSES` or `classes` should be specified when testing'
|
||||
assert self._metainfo.get('classes') is not None, \
|
||||
'dataset metainfo `classes` should be specified when testing'
|
||||
|
||||
# join paths if data_root is specified
|
||||
# Join paths.
|
||||
if self.data_root is not None:
|
||||
if not osp.isabs(self.img_dir):
|
||||
self.img_dir = osp.join(self.data_root, self.img_dir)
|
||||
if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
|
||||
self.ann_dir = osp.join(self.data_root, self.ann_dir)
|
||||
if not (self.split is None or osp.isabs(self.split)):
|
||||
self.split = osp.join(self.data_root, self.split)
|
||||
self._join_prefix()
|
||||
|
||||
# load annotations
|
||||
self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
|
||||
self.ann_dir,
|
||||
self.seg_map_suffix, self.split)
|
||||
# Build pipeline.
|
||||
self.pipeline = Compose(pipeline)
|
||||
# Full initialize the dataset.
|
||||
if not lazy_init:
|
||||
self.full_init()
|
||||
|
||||
def __len__(self):
|
||||
"""Total number of samples of data."""
|
||||
return len(self.img_infos)
|
||||
@classmethod
|
||||
def get_label_map(cls,
|
||||
new_classes: Optional[Sequence] = None
|
||||
) -> Union[Dict, None]:
|
||||
"""Require label mapping.
|
||||
|
||||
def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
|
||||
split):
|
||||
"""Load annotation from directory.
|
||||
The ``label_map`` is a dictionary, its keys are the old label ids and
|
||||
its values are the new label ids, and is used for changing pixel
|
||||
labels in load_annotations. If and only if old classes in cls.METAINFO
|
||||
is not equal to new classes in self._metainfo and nether of them is not
|
||||
None, `label_map` is not None.
|
||||
|
||||
Args:
|
||||
img_dir (str): Path to image directory
|
||||
img_suffix (str): Suffix of images.
|
||||
ann_dir (str|None): Path to annotation directory.
|
||||
seg_map_suffix (str|None): Suffix of segmentation maps.
|
||||
split (str|None): Split txt file. If split is specified, only file
|
||||
with suffix in the splits will be loaded. Otherwise, all images
|
||||
in img_dir/ann_dir will be loaded. Default: None
|
||||
new_classes (list, tuple, optional): The new classes name from
|
||||
metainfo. Default to None.
|
||||
|
||||
|
||||
Returns:
|
||||
list[dict]: All image info of dataset.
|
||||
dict, optional: The mapping from old classes in cls.METAINFO to
|
||||
new classes in self._metainfo
|
||||
"""
|
||||
old_classes = cls.METAINFO.get('classes', None)
|
||||
if (new_classes is not None and old_classes is not None
|
||||
and list(new_classes) != list(old_classes)):
|
||||
|
||||
img_infos = []
|
||||
if split is not None:
|
||||
lines = mmcv.list_from_file(
|
||||
split, file_client_args=self.file_client_args)
|
||||
for line in lines:
|
||||
img_name = line.strip()
|
||||
img_info = dict(filename=img_name + img_suffix)
|
||||
if ann_dir is not None:
|
||||
seg_map = img_name + seg_map_suffix
|
||||
img_info['ann'] = dict(seg_map=seg_map)
|
||||
img_infos.append(img_info)
|
||||
else:
|
||||
for img in self.file_client.list_dir_or_file(
|
||||
dir_path=img_dir,
|
||||
list_dir=False,
|
||||
suffix=img_suffix,
|
||||
recursive=True):
|
||||
img_info = dict(filename=img)
|
||||
if ann_dir is not None:
|
||||
seg_map = img.replace(img_suffix, seg_map_suffix)
|
||||
img_info['ann'] = dict(seg_map=seg_map)
|
||||
img_infos.append(img_info)
|
||||
img_infos = sorted(img_infos, key=lambda x: x['filename'])
|
||||
|
||||
print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
|
||||
return img_infos
|
||||
|
||||
def get_ann_info(self, idx):
|
||||
"""Get annotation by index.
|
||||
|
||||
Args:
|
||||
idx (int): Index of data.
|
||||
|
||||
Returns:
|
||||
dict: Annotation info of specified index.
|
||||
"""
|
||||
|
||||
return self.img_infos[idx]['ann']
|
||||
|
||||
def pre_pipeline(self, results):
|
||||
"""Prepare results dict for pipeline."""
|
||||
results['seg_fields'] = []
|
||||
results['img_prefix'] = self.img_dir
|
||||
results['seg_prefix'] = self.ann_dir
|
||||
if self.custom_classes:
|
||||
results['label_map'] = self.label_map
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Get training/test data after pipeline.
|
||||
|
||||
Args:
|
||||
idx (int): Index of data.
|
||||
|
||||
Returns:
|
||||
dict: Training/test data (with annotation if `test_mode` is set
|
||||
False).
|
||||
"""
|
||||
|
||||
if self.test_mode:
|
||||
return self.prepare_test_img(idx)
|
||||
else:
|
||||
return self.prepare_train_img(idx)
|
||||
|
||||
def prepare_train_img(self, idx):
|
||||
"""Get training data and annotations after pipeline.
|
||||
|
||||
Args:
|
||||
idx (int): Index of data.
|
||||
|
||||
Returns:
|
||||
dict: Training data and annotation after pipeline with new keys
|
||||
introduced by pipeline.
|
||||
"""
|
||||
|
||||
img_info = self.img_infos[idx]
|
||||
ann_info = self.get_ann_info(idx)
|
||||
results = dict(img_info=img_info, ann_info=ann_info)
|
||||
self.pre_pipeline(results)
|
||||
return self.pipeline(results)
|
||||
|
||||
def prepare_test_img(self, idx):
|
||||
"""Get testing data after pipeline.
|
||||
|
||||
Args:
|
||||
idx (int): Index of data.
|
||||
|
||||
Returns:
|
||||
dict: Testing data after pipeline with new keys introduced by
|
||||
pipeline.
|
||||
"""
|
||||
|
||||
img_info = self.img_infos[idx]
|
||||
results = dict(img_info=img_info)
|
||||
self.pre_pipeline(results)
|
||||
return self.pipeline(results)
|
||||
|
||||
def format_results(self, results, imgfile_prefix, indices=None, **kwargs):
|
||||
"""Place holder to format result to dataset specific output."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_gt_seg_map_by_idx(self, index):
|
||||
"""Get one ground truth segmentation map for evaluation."""
|
||||
ann_info = self.get_ann_info(index)
|
||||
results = dict(ann_info=ann_info)
|
||||
self.pre_pipeline(results)
|
||||
self.gt_seg_map_loader(results)
|
||||
return results['gt_semantic_seg']
|
||||
|
||||
def get_gt_seg_maps(self, efficient_test=None):
|
||||
"""Get ground truth segmentation maps for evaluation."""
|
||||
if efficient_test is not None:
|
||||
warnings.warn(
|
||||
'DeprecationWarning: ``efficient_test`` has been deprecated '
|
||||
'since MMSeg v0.16, the ``get_gt_seg_maps()`` is CPU memory '
|
||||
'friendly by default. ')
|
||||
|
||||
for idx in range(len(self)):
|
||||
ann_info = self.get_ann_info(idx)
|
||||
results = dict(ann_info=ann_info)
|
||||
self.pre_pipeline(results)
|
||||
self.gt_seg_map_loader(results)
|
||||
yield results['gt_semantic_seg']
|
||||
|
||||
def pre_eval(self, preds, indices):
|
||||
"""Collect eval result from each iteration.
|
||||
|
||||
Args:
|
||||
preds (list[torch.Tensor] | torch.Tensor): the segmentation logit
|
||||
after argmax, shape (N, H, W).
|
||||
indices (list[int] | int): the prediction related ground truth
|
||||
indices.
|
||||
|
||||
Returns:
|
||||
list[torch.Tensor]: (area_intersect, area_union, area_prediction,
|
||||
area_ground_truth).
|
||||
"""
|
||||
# In order to compat with batch inference
|
||||
if not isinstance(indices, list):
|
||||
indices = [indices]
|
||||
if not isinstance(preds, list):
|
||||
preds = [preds]
|
||||
|
||||
pre_eval_results = []
|
||||
|
||||
for pred, index in zip(preds, indices):
|
||||
seg_map = self.get_gt_seg_map_by_idx(index)
|
||||
pre_eval_results.append(
|
||||
intersect_and_union(
|
||||
pred,
|
||||
seg_map,
|
||||
len(self.CLASSES),
|
||||
self.ignore_index,
|
||||
# as the labels has been converted when dataset initialized
|
||||
# in `get_palette_for_custom_classes ` this `label_map`
|
||||
# should be `dict()`, see
|
||||
# https://github.com/open-mmlab/mmsegmentation/issues/1415
|
||||
# for more ditails
|
||||
label_map=dict(),
|
||||
reduce_zero_label=self.reduce_zero_label))
|
||||
|
||||
return pre_eval_results
|
||||
|
||||
def get_classes_and_palette(self, classes=None, palette=None):
|
||||
"""Get class names of current dataset.
|
||||
|
||||
Args:
|
||||
classes (Sequence[str] | str | None): If classes is None, use
|
||||
default CLASSES defined by builtin dataset. If classes is a
|
||||
string, take it as a file name. The file contains the name of
|
||||
classes where each line contains one class name. If classes is
|
||||
a tuple or list, override the CLASSES defined by the dataset.
|
||||
palette (Sequence[Sequence[int]]] | np.ndarray | None):
|
||||
The palette of segmentation map. If None is given, random
|
||||
palette will be generated. Default: None
|
||||
"""
|
||||
if classes is None:
|
||||
self.custom_classes = False
|
||||
return self.CLASSES, self.PALETTE
|
||||
|
||||
self.custom_classes = True
|
||||
if isinstance(classes, str):
|
||||
# take it as a file path
|
||||
class_names = mmcv.list_from_file(classes)
|
||||
elif isinstance(classes, (tuple, list)):
|
||||
class_names = classes
|
||||
else:
|
||||
raise ValueError(f'Unsupported type {type(classes)} of classes.')
|
||||
|
||||
if self.CLASSES:
|
||||
if not set(class_names).issubset(self.CLASSES):
|
||||
raise ValueError('classes is not a subset of CLASSES.')
|
||||
|
||||
# dictionary, its keys are the old label ids and its values
|
||||
# are the new label ids.
|
||||
# used for changing pixel labels in load_annotations.
|
||||
self.label_map = {}
|
||||
for i, c in enumerate(self.CLASSES):
|
||||
if c not in class_names:
|
||||
self.label_map[i] = -1
|
||||
label_map = {}
|
||||
if not set(new_classes).issubset(cls.METAINFO['classes']):
|
||||
raise ValueError(
|
||||
f'new classes {new_classes} is not a '
|
||||
f'subset of classes {old_classes} in METAINFO.')
|
||||
for i, c in enumerate(old_classes):
|
||||
if c not in new_classes:
|
||||
label_map[i] = -1
|
||||
else:
|
||||
self.label_map[i] = class_names.index(c)
|
||||
label_map[i] = new_classes.index(c)
|
||||
return label_map
|
||||
else:
|
||||
return None
|
||||
|
||||
palette = self.get_palette_for_custom_classes(class_names, palette)
|
||||
def _update_palette(self) -> list:
|
||||
"""Update palette after loading metainfo.
|
||||
|
||||
return class_names, palette
|
||||
If length of palette is equal to classes, just return the palette.
|
||||
If palette is not defined, it will randomly generate a palette.
|
||||
If classes is updated by customer, it will return the subset of
|
||||
palette.
|
||||
|
||||
def get_palette_for_custom_classes(self, class_names, palette=None):
|
||||
Returns:
|
||||
Sequence: Palette for current dataset.
|
||||
"""
|
||||
palette = self._metainfo.get('palette', [])
|
||||
classes = self._metainfo.get('classes', [])
|
||||
# palette does match classes
|
||||
if len(palette) == len(classes):
|
||||
return palette
|
||||
|
||||
if self.label_map is not None:
|
||||
if len(palette) == 0:
|
||||
# Get random state before set seed, and restore
|
||||
# random state later.
|
||||
# It will prevent loss of randomness, as the palette
|
||||
# may be different in each iteration if not specified.
|
||||
# See: https://github.com/open-mmlab/mmdetection/issues/5844
|
||||
state = np.random.get_state()
|
||||
np.random.seed(42)
|
||||
# random palette
|
||||
new_palette = np.random.randint(
|
||||
0, 255, size=(len(classes), 3)).tolist()
|
||||
np.random.set_state(state)
|
||||
elif len(palette) >= len(classes) and self.label_map is not None:
|
||||
new_palette = []
|
||||
# return subset of palette
|
||||
palette = []
|
||||
for old_id, new_id in sorted(
|
||||
self.label_map.items(), key=lambda x: x[1]):
|
||||
if new_id != -1:
|
||||
palette.append(self.PALETTE[old_id])
|
||||
palette = type(self.PALETTE)(palette)
|
||||
new_palette.append(palette[old_id])
|
||||
new_palette = type(palette)(new_palette)
|
||||
else:
|
||||
raise ValueError('palette does not match classes '
|
||||
f'as metainfo is {self._metainfo}.')
|
||||
return new_palette
|
||||
|
||||
elif palette is None:
|
||||
if self.PALETTE is None:
|
||||
# Get random state before set seed, and restore
|
||||
# random state later.
|
||||
# It will prevent loss of randomness, as the palette
|
||||
# may be different in each iteration if not specified.
|
||||
# See: https://github.com/open-mmlab/mmdetection/issues/5844
|
||||
state = np.random.get_state()
|
||||
np.random.seed(42)
|
||||
# random palette
|
||||
palette = np.random.randint(0, 255, size=(len(class_names), 3))
|
||||
np.random.set_state(state)
|
||||
else:
|
||||
palette = self.PALETTE
|
||||
|
||||
return palette
|
||||
|
||||
def evaluate(self,
|
||||
results,
|
||||
metric='mIoU',
|
||||
logger=None,
|
||||
gt_seg_maps=None,
|
||||
**kwargs):
|
||||
"""Evaluate the dataset.
|
||||
|
||||
Args:
|
||||
results (list[tuple[torch.Tensor]] | list[str]): per image pre_eval
|
||||
results or predict segmentation map for computing evaluation
|
||||
metric.
|
||||
metric (str | list[str]): Metrics to be evaluated. 'mIoU',
|
||||
'mDice' and 'mFscore' are supported.
|
||||
logger (logging.Logger | None | str): Logger used for printing
|
||||
related information during evaluation. Default: None.
|
||||
gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
|
||||
used in ConcatDataset
|
||||
def load_data_list(self) -> List[dict]:
|
||||
"""Load annotation from directory or annotation file.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: Default metrics.
|
||||
list[dict]: All data info of dataset.
|
||||
"""
|
||||
if isinstance(metric, str):
|
||||
metric = [metric]
|
||||
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
|
||||
if not set(metric).issubset(set(allowed_metrics)):
|
||||
raise KeyError('metric {} is not supported'.format(metric))
|
||||
|
||||
eval_results = {}
|
||||
# test a list of files
|
||||
if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
|
||||
results, str):
|
||||
if gt_seg_maps is None:
|
||||
gt_seg_maps = self.get_gt_seg_maps()
|
||||
num_classes = len(self.CLASSES)
|
||||
ret_metrics = eval_metrics(
|
||||
results,
|
||||
gt_seg_maps,
|
||||
num_classes,
|
||||
self.ignore_index,
|
||||
metric,
|
||||
label_map=dict(),
|
||||
reduce_zero_label=self.reduce_zero_label)
|
||||
# test a list of pre_eval_results
|
||||
data_list = []
|
||||
ann_dir = self.data_prefix.get('seg_map_path', None)
|
||||
if osp.isfile(self.ann_file):
|
||||
lines = mmcv.list_from_file(
|
||||
self.ann_file, file_client_args=self.file_client_args)
|
||||
for line in lines:
|
||||
img_name = line.strip()
|
||||
data_info = dict(img_path=img_name + self.img_suffix)
|
||||
if ann_dir is not None:
|
||||
seg_map = img_name + self.seg_map_suffix
|
||||
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
|
||||
data_info['label_map'] = self.label_map
|
||||
data_info['seg_field'] = []
|
||||
data_list.append(data_info)
|
||||
else:
|
||||
ret_metrics = pre_eval_to_metrics(results, metric)
|
||||
|
||||
# Because dataset.CLASSES is required for per-eval.
|
||||
if self.CLASSES is None:
|
||||
class_names = tuple(range(num_classes))
|
||||
else:
|
||||
class_names = self.CLASSES
|
||||
|
||||
# summary table
|
||||
ret_metrics_summary = OrderedDict({
|
||||
ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
|
||||
for ret_metric, ret_metric_value in ret_metrics.items()
|
||||
})
|
||||
|
||||
# each class table
|
||||
ret_metrics.pop('aAcc', None)
|
||||
ret_metrics_class = OrderedDict({
|
||||
ret_metric: np.round(ret_metric_value * 100, 2)
|
||||
for ret_metric, ret_metric_value in ret_metrics.items()
|
||||
})
|
||||
ret_metrics_class.update({'Class': class_names})
|
||||
ret_metrics_class.move_to_end('Class', last=False)
|
||||
|
||||
# for logger
|
||||
class_table_data = PrettyTable()
|
||||
for key, val in ret_metrics_class.items():
|
||||
class_table_data.add_column(key, val)
|
||||
|
||||
summary_table_data = PrettyTable()
|
||||
for key, val in ret_metrics_summary.items():
|
||||
if key == 'aAcc':
|
||||
summary_table_data.add_column(key, [val])
|
||||
else:
|
||||
summary_table_data.add_column('m' + key, [val])
|
||||
|
||||
print_log('per class results:', logger)
|
||||
print_log('\n' + class_table_data.get_string(), logger=logger)
|
||||
print_log('Summary:', logger)
|
||||
print_log('\n' + summary_table_data.get_string(), logger=logger)
|
||||
|
||||
# each metric dict
|
||||
for key, value in ret_metrics_summary.items():
|
||||
if key == 'aAcc':
|
||||
eval_results[key] = value / 100.0
|
||||
else:
|
||||
eval_results['m' + key] = value / 100.0
|
||||
|
||||
ret_metrics_class.pop('Class', None)
|
||||
for key, value in ret_metrics_class.items():
|
||||
eval_results.update({
|
||||
key + '.' + str(name): value[idx] / 100.0
|
||||
for idx, name in enumerate(class_names)
|
||||
})
|
||||
|
||||
return eval_results
|
||||
img_dir = self.data_prefix['img_path']
|
||||
for img in self.file_client.list_dir_or_file(
|
||||
dir_path=img_dir,
|
||||
list_dir=False,
|
||||
suffix=self.img_suffix,
|
||||
recursive=True):
|
||||
data_info = dict(img_path=osp.join(img_dir, img))
|
||||
if ann_dir is not None:
|
||||
seg_map = img.replace(self.img_suffix, self.seg_map_suffix)
|
||||
data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
|
||||
data_info['label_map'] = self.label_map
|
||||
data_info['seg_field'] = []
|
||||
data_list.append(data_info)
|
||||
data_list = sorted(data_list, key=lambda x: x['img_path'])
|
||||
return data_list
|
||||
|
|
|
@ -7,7 +7,7 @@ from .cityscapes import CityscapesDataset
|
|||
class DarkZurichDataset(CityscapesDataset):
|
||||
"""DarkZurichDataset dataset."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='_rgb_anon.png',
|
||||
seg_map_suffix='_gt_labelTrainIds.png',
|
||||
|
|
|
@ -13,15 +13,14 @@ class DRIVEDataset(CustomDataset):
|
|||
``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
|
||||
'_manual1.png'.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes=('background', 'vessel'),
|
||||
palette=[[120, 120, 120], [6, 230, 230]])
|
||||
|
||||
CLASSES = ('background', 'vessel')
|
||||
|
||||
PALETTE = [[120, 120, 120], [6, 230, 230]]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(DRIVEDataset, self).__init__(
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='.png',
|
||||
seg_map_suffix='_manual1.png',
|
||||
reduce_zero_label=False,
|
||||
**kwargs)
|
||||
assert self.file_client.exists(self.img_dir)
|
||||
assert self.file_client.exists(self.data_prefix['img_path'])
|
||||
|
|
|
@ -13,15 +13,14 @@ class HRFDataset(CustomDataset):
|
|||
``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
|
||||
'.png'.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes=('background', 'vessel'),
|
||||
palette=[[120, 120, 120], [6, 230, 230]])
|
||||
|
||||
CLASSES = ('background', 'vessel')
|
||||
|
||||
PALETTE = [[120, 120, 120], [6, 230, 230]]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(HRFDataset, self).__init__(
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='.png',
|
||||
seg_map_suffix='.png',
|
||||
reduce_zero_label=False,
|
||||
**kwargs)
|
||||
assert self.file_client.exists(self.img_dir)
|
||||
assert self.file_client.exists(self.data_prefix['img_path'])
|
||||
|
|
|
@ -1,10 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import mmcv
|
||||
from mmcv.utils import print_log
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from ..utils import get_root_logger
|
||||
from .custom import CustomDataset
|
||||
|
||||
|
||||
|
@ -17,66 +12,21 @@ class iSAIDDataset(CustomDataset):
|
|||
'_manual1.png'.
|
||||
"""
|
||||
|
||||
CLASSES = ('background', 'ship', 'store_tank', 'baseball_diamond',
|
||||
'tennis_court', 'basketball_court', 'Ground_Track_Field',
|
||||
'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter',
|
||||
'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane',
|
||||
'Harbor')
|
||||
METAINFO = dict(
|
||||
classes=('background', 'ship', 'store_tank', 'baseball_diamond',
|
||||
'tennis_court', 'basketball_court', 'Ground_Track_Field',
|
||||
'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter',
|
||||
'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane',
|
||||
'Harbor'),
|
||||
palette=[[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
|
||||
[0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
|
||||
[0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127],
|
||||
[0, 127, 191], [0, 127, 255], [0, 100, 155]])
|
||||
|
||||
PALETTE = [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
|
||||
[0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
|
||||
[0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127],
|
||||
[0, 127, 191], [0, 127, 255], [0, 100, 155]]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(iSAIDDataset, self).__init__(
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='.png',
|
||||
seg_map_suffix='.png',
|
||||
seg_map_suffix='_instance_color_RGB.png',
|
||||
ignore_index=255,
|
||||
**kwargs)
|
||||
assert self.file_client.exists(self.img_dir)
|
||||
|
||||
def load_annotations(self,
|
||||
img_dir,
|
||||
img_suffix,
|
||||
ann_dir,
|
||||
seg_map_suffix=None,
|
||||
split=None):
|
||||
"""Load annotation from directory.
|
||||
|
||||
Args:
|
||||
img_dir (str): Path to image directory
|
||||
img_suffix (str): Suffix of images.
|
||||
ann_dir (str|None): Path to annotation directory.
|
||||
seg_map_suffix (str|None): Suffix of segmentation maps.
|
||||
split (str|None): Split txt file. If split is specified, only file
|
||||
with suffix in the splits will be loaded. Otherwise, all images
|
||||
in img_dir/ann_dir will be loaded. Default: None
|
||||
|
||||
Returns:
|
||||
list[dict]: All image info of dataset.
|
||||
"""
|
||||
|
||||
img_infos = []
|
||||
if split is not None:
|
||||
with open(split) as f:
|
||||
for line in f:
|
||||
name = line.strip()
|
||||
img_info = dict(filename=name + img_suffix)
|
||||
if ann_dir is not None:
|
||||
ann_name = name + '_instance_color_RGB'
|
||||
seg_map = ann_name + seg_map_suffix
|
||||
img_info['ann'] = dict(seg_map=seg_map)
|
||||
img_infos.append(img_info)
|
||||
else:
|
||||
for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
|
||||
img_info = dict(filename=img)
|
||||
if ann_dir is not None:
|
||||
seg_img = img
|
||||
seg_map = seg_img.replace(
|
||||
img_suffix, '_instance_color_RGB' + seg_map_suffix)
|
||||
img_info['ann'] = dict(seg_map=seg_map)
|
||||
img_infos.append(img_info)
|
||||
|
||||
print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
|
||||
return img_infos
|
||||
assert self.file_client.exists(self.data_prefix['img_path'])
|
||||
|
|
|
@ -11,14 +11,14 @@ class ISPRSDataset(CustomDataset):
|
|||
``reduce_zero_label`` should be set to True. The ``img_suffix`` and
|
||||
``seg_map_suffix`` are both fixed to '.png'.
|
||||
"""
|
||||
CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree',
|
||||
'car', 'clutter')
|
||||
METAINFO = dict(
|
||||
classes=('impervious_surface', 'building', 'low_vegetation', 'tree',
|
||||
'car', 'clutter'),
|
||||
palette=[[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
|
||||
[255, 255, 0], [255, 0, 0]])
|
||||
|
||||
PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
|
||||
[255, 255, 0], [255, 0, 0]]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(ISPRSDataset, self).__init__(
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='.png',
|
||||
seg_map_suffix='.png',
|
||||
reduce_zero_label=True,
|
||||
|
|
|
@ -1,10 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import CustomDataset
|
||||
|
||||
|
@ -17,76 +11,15 @@ class LoveDADataset(CustomDataset):
|
|||
``reduce_zero_label`` should be set to True. The ``img_suffix`` and
|
||||
``seg_map_suffix`` are both fixed to '.png'.
|
||||
"""
|
||||
CLASSES = ('background', 'building', 'road', 'water', 'barren', 'forest',
|
||||
'agricultural')
|
||||
METAINFO = dict(
|
||||
classes=('background', 'building', 'road', 'water', 'barren', 'forest',
|
||||
'agricultural'),
|
||||
palette=[[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
|
||||
[159, 129, 183], [0, 255, 0], [255, 195, 128]])
|
||||
|
||||
PALETTE = [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
|
||||
[159, 129, 183], [0, 255, 0], [255, 195, 128]]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(LoveDADataset, self).__init__(
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='.png',
|
||||
seg_map_suffix='.png',
|
||||
reduce_zero_label=True,
|
||||
**kwargs)
|
||||
|
||||
def results2img(self, results, imgfile_prefix, indices=None):
|
||||
"""Write the segmentation results to images.
|
||||
|
||||
Args:
|
||||
results (list[ndarray]): Testing results of the
|
||||
dataset.
|
||||
imgfile_prefix (str): The filename prefix of the png files.
|
||||
If the prefix is "somepath/xxx",
|
||||
the png files will be named "somepath/xxx.png".
|
||||
indices (list[int], optional): Indices of input results, if not
|
||||
set, all the indices of the dataset will be used.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
list[str: str]: result txt files which contains corresponding
|
||||
semantic segmentation images.
|
||||
"""
|
||||
|
||||
mmcv.mkdir_or_exist(imgfile_prefix)
|
||||
result_files = []
|
||||
for result, idx in zip(results, indices):
|
||||
|
||||
filename = self.img_infos[idx]['filename']
|
||||
basename = osp.splitext(osp.basename(filename))[0]
|
||||
|
||||
png_filename = osp.join(imgfile_prefix, f'{basename}.png')
|
||||
|
||||
# The index range of official requirement is from 0 to 6.
|
||||
output = Image.fromarray(result.astype(np.uint8))
|
||||
output.save(png_filename)
|
||||
result_files.append(png_filename)
|
||||
|
||||
return result_files
|
||||
|
||||
def format_results(self, results, imgfile_prefix, indices=None):
|
||||
"""Format the results into dir (standard format for LoveDA evaluation).
|
||||
|
||||
Args:
|
||||
results (list): Testing results of the dataset.
|
||||
imgfile_prefix (str): The prefix of images files. It
|
||||
includes the file path and the prefix of filename, e.g.,
|
||||
"a/b/prefix".
|
||||
indices (list[int], optional): Indices of input results,
|
||||
if not set, all the indices of the dataset will be used.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
tuple: (result_files, tmp_dir), result_files is a list containing
|
||||
the image paths, tmp_dir is the temporal directory created
|
||||
for saving json/png files when img_prefix is not specified.
|
||||
"""
|
||||
if indices is None:
|
||||
indices = list(range(len(self)))
|
||||
|
||||
assert isinstance(results, list), 'results must be a list.'
|
||||
assert isinstance(indices, list), 'indices must be a list.'
|
||||
|
||||
result_files = self.results2img(results, imgfile_prefix, indices)
|
||||
|
||||
return result_files
|
||||
|
|
|
@ -7,7 +7,7 @@ from .cityscapes import CityscapesDataset
|
|||
class NightDrivingDataset(CityscapesDataset):
|
||||
"""NightDrivingDataset dataset."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='_leftImg8bit.png',
|
||||
seg_map_suffix='_gtCoarse_labelTrainIds.png',
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import CustomDataset
|
||||
|
@ -14,44 +15,45 @@ class PascalContextDataset(CustomDataset):
|
|||
fixed to '.png'.
|
||||
|
||||
Args:
|
||||
split (str): Split txt file for PascalContext.
|
||||
ann_file (str): Annotation file path.
|
||||
"""
|
||||
|
||||
CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
|
||||
'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
|
||||
'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
|
||||
'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
|
||||
'floor', 'flower', 'food', 'grass', 'ground', 'horse',
|
||||
'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
|
||||
'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
|
||||
'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
|
||||
'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
|
||||
'window', 'wood')
|
||||
METAINFO = dict(
|
||||
classes=('background', 'aeroplane', 'bag', 'bed', 'bedclothes',
|
||||
'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle',
|
||||
'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling',
|
||||
'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog',
|
||||
'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground',
|
||||
'horse', 'keyboard', 'light', 'motorbike', 'mountain',
|
||||
'mouse', 'person', 'plate', 'platform', 'pottedplant', 'road',
|
||||
'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow',
|
||||
'sofa', 'table', 'track', 'train', 'tree', 'truck',
|
||||
'tvmonitor', 'wall', 'water', 'window', 'wood'),
|
||||
palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
|
||||
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
|
||||
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
||||
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
||||
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
||||
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
||||
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
||||
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
||||
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
||||
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
||||
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
||||
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
||||
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
||||
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
||||
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])
|
||||
|
||||
PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
|
||||
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
|
||||
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
||||
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
||||
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
||||
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
||||
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
||||
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
||||
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
||||
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
||||
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
||||
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
||||
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
||||
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
||||
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]
|
||||
|
||||
def __init__(self, split, **kwargs):
|
||||
super(PascalContextDataset, self).__init__(
|
||||
def __init__(self, ann_file: str, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='.jpg',
|
||||
seg_map_suffix='.png',
|
||||
split=split,
|
||||
ann_file=ann_file,
|
||||
reduce_zero_label=False,
|
||||
**kwargs)
|
||||
assert self.file_client.exists(self.img_dir) and self.split is not None
|
||||
assert self.file_client.exists(
|
||||
self.data_prefix['img_path']) and osp.isfile(self.ann_file)
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -64,40 +66,41 @@ class PascalContextDataset59(CustomDataset):
|
|||
fixed to '.png'.
|
||||
|
||||
Args:
|
||||
split (str): Split txt file for PascalContext.
|
||||
ann_file (str): Annotation file path.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes=('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
|
||||
'bird', 'boat', 'book', 'bottle', 'building', 'bus',
|
||||
'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
|
||||
'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
|
||||
'floor', 'flower', 'food', 'grass', 'ground', 'horse',
|
||||
'keyboard', 'light', 'motorbike', 'mountain', 'mouse',
|
||||
'person', 'plate', 'platform', 'pottedplant', 'road', 'rock',
|
||||
'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa',
|
||||
'table', 'track', 'train', 'tree', 'truck', 'tvmonitor',
|
||||
'wall', 'water', 'window', 'wood'),
|
||||
palette=[[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
|
||||
[120, 120, 80], [140, 140, 140], [204, 5, 255],
|
||||
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
||||
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
||||
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
||||
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
||||
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
||||
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
||||
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
||||
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
||||
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
||||
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
||||
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
||||
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
||||
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]])
|
||||
|
||||
CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
|
||||
'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet',
|
||||
'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
|
||||
'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
|
||||
'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
|
||||
'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
|
||||
'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
|
||||
'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
|
||||
'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')
|
||||
|
||||
PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
|
||||
[120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
|
||||
[4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
|
||||
[120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
|
||||
[204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
|
||||
[61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
|
||||
[255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
|
||||
[112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
|
||||
[10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
|
||||
[102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
|
||||
[0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
|
||||
[235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
|
||||
[250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
|
||||
[255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
|
||||
[0, 235, 255], [0, 173, 255], [31, 0, 255]]
|
||||
|
||||
def __init__(self, split, **kwargs):
|
||||
super(PascalContextDataset59, self).__init__(
|
||||
def __init__(self, ann_file: str, **kwargs):
|
||||
super().__init__(
|
||||
img_suffix='.jpg',
|
||||
seg_map_suffix='.png',
|
||||
split=split,
|
||||
ann_file=ann_file,
|
||||
reduce_zero_label=True,
|
||||
**kwargs)
|
||||
assert self.file_client.exists(self.img_dir) and self.split is not None
|
||||
assert self.file_client.exists(
|
||||
self.data_prefix['img_path']) and osp.isfile(self.ann_file)
|
||||
|
|
|
@ -11,14 +11,14 @@ class PotsdamDataset(CustomDataset):
|
|||
``reduce_zero_label`` should be set to True. The ``img_suffix`` and
|
||||
``seg_map_suffix`` are both fixed to '.png'.
|
||||
"""
|
||||
CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree',
|
||||
'car', 'clutter')
|
||||
METAINFO = dict(
|
||||
classes=('impervious_surface', 'building', 'low_vegetation', 'tree',
|
||||
'car', 'clutter'),
|
||||
palette=[[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
|
||||
[255, 255, 0], [255, 0, 0]])
|
||||
|
||||
PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
|
||||
[255, 255, 0], [255, 0, 0]]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(PotsdamDataset, self).__init__(
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='.png',
|
||||
seg_map_suffix='.png',
|
||||
reduce_zero_label=True,
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import CustomDataset
|
||||
|
||||
|
@ -14,15 +12,14 @@ class STAREDataset(CustomDataset):
|
|||
``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
|
||||
'.ah.png'.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes=('background', 'vessel'),
|
||||
palette=[[120, 120, 120], [6, 230, 230]])
|
||||
|
||||
CLASSES = ('background', 'vessel')
|
||||
|
||||
PALETTE = [[120, 120, 120], [6, 230, 230]]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(STAREDataset, self).__init__(
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='.png',
|
||||
seg_map_suffix='.ah.png',
|
||||
reduce_zero_label=False,
|
||||
**kwargs)
|
||||
assert osp.exists(self.img_dir)
|
||||
assert self.file_client.exists(self.data_prefix['img_path'])
|
||||
|
|
|
@ -12,19 +12,23 @@ class PascalVOCDataset(CustomDataset):
|
|||
Args:
|
||||
split (str): Split txt file for Pascal VOC.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes=('background', 'aeroplane', 'bicycle', 'bird', 'boat',
|
||||
'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
|
||||
'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
|
||||
'sofa', 'train', 'tvmonitor'),
|
||||
palette=[[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
|
||||
[0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
|
||||
[64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
|
||||
[64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
|
||||
[0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
|
||||
[0, 64, 128]])
|
||||
|
||||
CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
|
||||
'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
|
||||
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
|
||||
'train', 'tvmonitor')
|
||||
|
||||
PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
|
||||
[128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
|
||||
[192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
|
||||
[192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
|
||||
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
|
||||
|
||||
def __init__(self, split, **kwargs):
|
||||
super(PascalVOCDataset, self).__init__(
|
||||
img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs)
|
||||
assert osp.exists(self.img_dir) and self.split is not None
|
||||
def __init__(self, ann_file, **kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix='.jpg',
|
||||
seg_map_suffix='.png',
|
||||
ann_file=ann_file,
|
||||
**kwargs)
|
||||
assert self.file_client.exists(
|
||||
self.data_prefix['img_path']) and osp.isfile(self.ann_file)
|
||||
|
|
|
@ -0,0 +1,353 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from mmseg.core.evaluation import get_classes, get_palette
|
||||
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
|
||||
COCOStuffDataset, CustomDataset, ISPRSDataset,
|
||||
LoveDADataset, PascalVOCDataset, PotsdamDataset,
|
||||
iSAIDDataset)
|
||||
|
||||
|
||||
def test_classes():
|
||||
assert list(
|
||||
CityscapesDataset.METAINFO['classes']) == get_classes('cityscapes')
|
||||
assert list(PascalVOCDataset.METAINFO['classes']) == get_classes(
|
||||
'voc') == get_classes('pascal_voc')
|
||||
assert list(ADE20KDataset.METAINFO['classes']) == get_classes(
|
||||
'ade') == get_classes('ade20k')
|
||||
assert list(
|
||||
COCOStuffDataset.METAINFO['classes']) == get_classes('cocostuff')
|
||||
assert list(LoveDADataset.METAINFO['classes']) == get_classes('loveda')
|
||||
assert list(PotsdamDataset.METAINFO['classes']) == get_classes('potsdam')
|
||||
assert list(ISPRSDataset.METAINFO['classes']) == get_classes('vaihingen')
|
||||
assert list(iSAIDDataset.METAINFO['classes']) == get_classes('isaid')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
get_classes('unsupported')
|
||||
|
||||
|
||||
def test_classes_file_path():
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
classes_path = f'{tmp_file.name}.txt'
|
||||
train_pipeline = []
|
||||
kwargs = dict(
|
||||
pipeline=train_pipeline,
|
||||
data_prefix=dict(img_path='./', seg_map_path='./'),
|
||||
metainfo=dict(classes=classes_path))
|
||||
|
||||
# classes.txt with full categories
|
||||
categories = get_classes('cityscapes')
|
||||
with open(classes_path, 'w') as f:
|
||||
f.write('\n'.join(categories))
|
||||
dataset = CityscapesDataset(**kwargs)
|
||||
assert list(dataset.metainfo['classes']) == categories
|
||||
assert dataset.label_map is None
|
||||
|
||||
# classes.txt with sub categories
|
||||
categories = ['road', 'sidewalk', 'building']
|
||||
with open(classes_path, 'w') as f:
|
||||
f.write('\n'.join(categories))
|
||||
dataset = CityscapesDataset(**kwargs)
|
||||
assert list(dataset.metainfo['classes']) == categories
|
||||
assert dataset.label_map is not None
|
||||
|
||||
# classes.txt with unknown categories
|
||||
categories = ['road', 'sidewalk', 'unknown']
|
||||
with open(classes_path, 'w') as f:
|
||||
f.write('\n'.join(categories))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
CityscapesDataset(**kwargs)
|
||||
|
||||
tmp_file.close()
|
||||
os.remove(classes_path)
|
||||
assert not osp.exists(classes_path)
|
||||
|
||||
|
||||
def test_palette():
|
||||
assert CityscapesDataset.METAINFO['palette'] == get_palette('cityscapes')
|
||||
assert PascalVOCDataset.METAINFO['palette'] == get_palette(
|
||||
'voc') == get_palette('pascal_voc')
|
||||
assert ADE20KDataset.METAINFO['palette'] == get_palette(
|
||||
'ade') == get_palette('ade20k')
|
||||
assert LoveDADataset.METAINFO['palette'] == get_palette('loveda')
|
||||
assert PotsdamDataset.METAINFO['palette'] == get_palette('potsdam')
|
||||
assert COCOStuffDataset.METAINFO['palette'] == get_palette('cocostuff')
|
||||
assert iSAIDDataset.METAINFO['palette'] == get_palette('isaid')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
get_palette('unsupported')
|
||||
|
||||
|
||||
def test_custom_dataset():
|
||||
|
||||
# with 'img_path' and 'seg_map_path' in data_prefix
|
||||
train_dataset = CustomDataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
|
||||
data_prefix=dict(
|
||||
img_path='imgs/',
|
||||
seg_map_path='gts/',
|
||||
),
|
||||
img_suffix='img.jpg',
|
||||
seg_map_suffix='gt.png')
|
||||
assert len(train_dataset) == 5
|
||||
|
||||
# with 'img_path' and 'seg_map_path' in data_prefix and ann_file
|
||||
train_dataset = CustomDataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
|
||||
data_prefix=dict(
|
||||
img_path='imgs/',
|
||||
seg_map_path='gts/',
|
||||
),
|
||||
img_suffix='img.jpg',
|
||||
seg_map_suffix='gt.png',
|
||||
ann_file='splits/train.txt')
|
||||
assert len(train_dataset) == 4
|
||||
|
||||
# no data_root
|
||||
train_dataset = CustomDataset(
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_dataset/gts')),
|
||||
img_suffix='img.jpg',
|
||||
seg_map_suffix='gt.png')
|
||||
assert len(train_dataset) == 5
|
||||
|
||||
# with data_root but 'img_path' and 'seg_map_path' in data_prefix are
|
||||
# abs path
|
||||
train_dataset = CustomDataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_dataset/gts')),
|
||||
img_suffix='img.jpg',
|
||||
seg_map_suffix='gt.png')
|
||||
assert len(train_dataset) == 5
|
||||
|
||||
# test_mode=True
|
||||
test_dataset = CustomDataset(
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_dataset/imgs')),
|
||||
img_suffix='img.jpg',
|
||||
test_mode=True,
|
||||
metainfo=dict(classes=('pseudo_class', )))
|
||||
assert len(test_dataset) == 5
|
||||
|
||||
# training data get
|
||||
train_data = train_dataset[0]
|
||||
assert isinstance(train_data, dict)
|
||||
assert 'img_path' in train_data and osp.isfile(train_data['img_path'])
|
||||
assert 'seg_map_path' in train_data and osp.isfile(
|
||||
train_data['seg_map_path'])
|
||||
|
||||
# test data get
|
||||
test_data = test_dataset[0]
|
||||
assert isinstance(test_data, dict)
|
||||
assert 'img_path' in train_data and osp.isfile(train_data['img_path'])
|
||||
assert 'seg_map_path' in train_data and osp.isfile(
|
||||
train_data['seg_map_path'])
|
||||
|
||||
|
||||
def test_ade():
|
||||
test_dataset = ADE20KDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_dataset/imgs')))
|
||||
assert len(test_dataset) == 5
|
||||
|
||||
|
||||
def test_cityscapes():
|
||||
test_dataset = CityscapesDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_cityscapes_dataset/leftImg8bit'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_cityscapes_dataset/gtFine')))
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
def test_loveda():
|
||||
test_dataset = LoveDADataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_loveda_dataset/img_dir'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_loveda_dataset/ann_dir')))
|
||||
assert len(test_dataset) == 3
|
||||
|
||||
|
||||
def test_potsdam():
|
||||
test_dataset = PotsdamDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_potsdam_dataset/img_dir'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_potsdam_dataset/ann_dir')))
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
def test_vaihingen():
|
||||
test_dataset = ISPRSDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_vaihingen_dataset/img_dir'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_vaihingen_dataset/ann_dir')))
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
def test_isaid():
|
||||
test_dataset = iSAIDDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_isaid_dataset/ann_dir')))
|
||||
assert len(test_dataset) == 2
|
||||
test_dataset = iSAIDDataset(
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_isaid_dataset/ann_dir')),
|
||||
ann_file=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_isaid_dataset/splits/train.txt'))
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dataset, classes', [
|
||||
('ADE20KDataset', ('wall', 'building')),
|
||||
('CityscapesDataset', ('road', 'sidewalk')),
|
||||
('CustomDataset', ('bus', 'car')),
|
||||
('PascalVOCDataset', ('aeroplane', 'bicycle')),
|
||||
])
|
||||
def test_custom_classes_override_default(dataset, classes):
|
||||
|
||||
dataset_class = DATASETS.get(dataset)
|
||||
if isinstance(dataset_class, PascalVOCDataset):
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
ann_file = f'{tmp_file.name}.txt'
|
||||
else:
|
||||
ann_file = MagicMock()
|
||||
|
||||
original_classes = dataset_class.METAINFO.get('classes', None)
|
||||
|
||||
# Test setting classes as a tuple
|
||||
custom_dataset = dataset_class(
|
||||
data_prefix=dict(img_path=MagicMock()),
|
||||
ann_file=ann_file,
|
||||
metainfo=dict(classes=classes),
|
||||
test_mode=True,
|
||||
lazy_init=True)
|
||||
|
||||
assert custom_dataset.metainfo['classes'] != original_classes
|
||||
assert custom_dataset.metainfo['classes'] == classes
|
||||
if not isinstance(custom_dataset, CustomDataset):
|
||||
assert isinstance(custom_dataset.label_map, dict)
|
||||
|
||||
# Test setting classes as a list
|
||||
custom_dataset = dataset_class(
|
||||
data_prefix=dict(img_path=MagicMock()),
|
||||
ann_file=ann_file,
|
||||
metainfo=dict(classes=list(classes)),
|
||||
test_mode=True,
|
||||
lazy_init=True)
|
||||
|
||||
assert custom_dataset.metainfo['classes'] != original_classes
|
||||
assert custom_dataset.metainfo['classes'] == list(classes)
|
||||
if not isinstance(custom_dataset, CustomDataset):
|
||||
assert isinstance(custom_dataset.label_map, dict)
|
||||
|
||||
# Test overriding not a subset
|
||||
custom_dataset = dataset_class(
|
||||
ann_file=ann_file,
|
||||
data_prefix=dict(img_path=MagicMock()),
|
||||
metainfo=dict(classes=[classes[0]]),
|
||||
test_mode=True,
|
||||
lazy_init=True)
|
||||
|
||||
assert custom_dataset.metainfo['classes'] != original_classes
|
||||
assert custom_dataset.metainfo['classes'] == [classes[0]]
|
||||
if not isinstance(custom_dataset, CustomDataset):
|
||||
assert isinstance(custom_dataset.label_map, dict)
|
||||
|
||||
# Test default behavior
|
||||
if dataset_class is CustomDataset:
|
||||
with pytest.raises(AssertionError):
|
||||
custom_dataset = dataset_class(
|
||||
ann_file=ann_file,
|
||||
data_prefix=dict(img_path=MagicMock()),
|
||||
metainfo=None,
|
||||
test_mode=True,
|
||||
lazy_init=True)
|
||||
else:
|
||||
custom_dataset = dataset_class(
|
||||
data_prefix=dict(img_path=MagicMock()),
|
||||
ann_file=ann_file,
|
||||
metainfo=None,
|
||||
test_mode=True,
|
||||
lazy_init=True)
|
||||
|
||||
assert custom_dataset.METAINFO['classes'] == original_classes
|
||||
assert custom_dataset.label_map is None
|
||||
|
||||
|
||||
def test_custom_dataset_random_palette_is_generated():
|
||||
dataset = CustomDataset(
|
||||
pipeline=[],
|
||||
data_prefix=dict(img_path=MagicMock()),
|
||||
ann_file=MagicMock(),
|
||||
metainfo=dict(classes=('bus', 'car')),
|
||||
lazy_init=True,
|
||||
test_mode=True)
|
||||
assert len(dataset.metainfo['palette']) == 2
|
||||
for class_color in dataset.metainfo['palette']:
|
||||
assert len(class_color) == 3
|
||||
assert all(x >= 0 and x <= 255 for x in class_color)
|
||||
|
||||
|
||||
def test_custom_dataset_custom_palette():
|
||||
dataset = CustomDataset(
|
||||
data_prefix=dict(img_path=MagicMock()),
|
||||
ann_file=MagicMock(),
|
||||
metainfo=dict(
|
||||
classes=('bus', 'car'), palette=[[100, 100, 100], [200, 200,
|
||||
200]]),
|
||||
lazy_init=True,
|
||||
test_mode=True)
|
||||
assert tuple(dataset.metainfo['palette']) == tuple([[100, 100, 100],
|
||||
[200, 200, 200]])
|
||||
# test custom class and palette don't match
|
||||
with pytest.raises(ValueError):
|
||||
dataset = CustomDataset(
|
||||
data_prefix=dict(img_path=MagicMock()),
|
||||
ann_file=MagicMock(),
|
||||
metainfo=dict(classes=('bus', 'car'), palette=[[200, 200, 200]]),
|
||||
lazy_init=True)
|
Loading…
Reference in New Issue