[Fix] Support format_result and fix prefix param in cityscape metric, and rename CitysMetric to CityscapesMetric (#2660)

as title
pull/2715/head
Miao Zheng 2023-03-07 17:57:37 +08:00 committed by GitHub
parent 6c3599bd9d
commit a8aafdd902
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 94 additions and 74 deletions

View File

@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .metrics import CitysMetric, IoUMetric
from .metrics import CityscapesMetric, IoUMetric
__all__ = ['IoUMetric', 'CitysMetric']
__all__ = ['IoUMetric', 'CityscapesMetric']

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .citys_metric import CitysMetric
from .citys_metric import CityscapesMetric
from .iou_metric import IoUMetric
__all__ = ['IoUMetric', 'CitysMetric']
__all__ = ['IoUMetric', 'CityscapesMetric']

View File

@ -1,30 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence
import shutil
from collections import OrderedDict
from typing import Dict, Optional, Sequence
try:
import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
import cityscapesscripts.helpers.labels as CSLabels
except ImportError:
CSLabels = None
CSEval = None
import numpy as np
from mmengine.dist import is_main_process, master_only
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist, scandir
from mmengine.utils import mkdir_or_exist
from PIL import Image
from mmseg.registry import METRICS
@METRICS.register_module()
class CitysMetric(BaseMetric):
class CityscapesMetric(BaseMetric):
"""Cityscapes evaluation metric.
Args:
output_dir (str): The directory for output prediction
ignore_index (int): Index that will be ignored in evaluation.
Default: 255.
citys_metrics (list[str] | str): Metrics to be evaluated,
Default: ['cityscapes'].
to_label_id (bool): whether convert output to label_id for
submission. Default: True.
suffix (str): The filename prefix of the png files.
If the prefix is "somepath/xxx", the png files will be
named "somepath/xxx.png". Default: '.format_cityscapes'.
format_only (bool): Only format result for results commit without
perform evaluation. It is useful when you want to format the result
to a specific format and submit it to the test server.
Defaults to False.
keep_results (bool): Whether to keep the results. When ``format_only``
is True, ``keep_results`` must be True. Defaults to False.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
@ -35,19 +46,34 @@ class CitysMetric(BaseMetric):
"""
def __init__(self,
output_dir: str,
ignore_index: int = 255,
citys_metrics: List[str] = ['cityscapes'],
to_label_id: bool = True,
suffix: str = '.format_cityscapes',
format_only: bool = False,
keep_results: bool = False,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
if CSEval is None:
raise ImportError('Please run "pip install cityscapesscripts" to '
'install cityscapesscripts first.')
self.output_dir = output_dir
self.ignore_index = ignore_index
self.metrics = citys_metrics
assert self.metrics[0] == 'cityscapes'
self.to_label_id = to_label_id
self.suffix = suffix
self.format_only = format_only
if format_only:
assert keep_results, (
'When format_only is True, the results must be keep, please '
f'set keep_results as True, but got {keep_results}')
self.keep_results = keep_results
self.prefix = prefix
if is_main_process():
mkdir_or_exist(self.output_dir)
@master_only
def __del__(self) -> None:
"""Clean up."""
if not self.keep_results:
shutil.rmtree(self.output_dir)
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
"""Process one batch of data and data_samples.
@ -59,26 +85,23 @@ class CitysMetric(BaseMetric):
data_batch (dict): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
mkdir_or_exist(self.suffix)
mkdir_or_exist(self.output_dir)
for data_sample in data_samples:
pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy()
# results2img
if self.to_label_id:
pred_label = self._convert_to_label_id(pred_label)
# when evaluating with official cityscapesscripts,
# labelIds should be used
pred_label = self._convert_to_label_id(pred_label)
basename = osp.splitext(osp.basename(data_sample['img_path']))[0]
png_filename = osp.join(self.suffix, f'{basename}.png')
png_filename = osp.abspath(
osp.join(self.output_dir, f'{basename}.png'))
output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
import cityscapesscripts.helpers.labels as CSLabels
palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
for label_id, label in CSLabels.id2label.items():
palette[label_id] = label.color
output.putpalette(palette)
output.save(png_filename)
ann_dir = osp.join(data_samples[0]['seg_map_path'].split('val')[0],
'val')
self.results.append(ann_dir)
# when evaluating with official cityscapesscripts,
# **_gtFine_labelIds.png is used
gt_filename = data_sample['seg_map_path'].replace(
'labelTrainIds.png', 'labelIds.png')
self.results.append((png_filename, gt_filename))
def compute_metrics(self, results: list) -> Dict[str, float]:
"""Compute the metrics from processed results.
@ -90,38 +113,28 @@ class CitysMetric(BaseMetric):
dict[str: float]: Cityscapes evaluation results.
"""
logger: MMLogger = MMLogger.get_current_instance()
try:
import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
except ImportError:
raise ImportError('Please run "pip install cityscapesscripts" to '
'install cityscapesscripts first.')
msg = 'Evaluating in Cityscapes style'
if self.format_only:
logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
return OrderedDict()
msg = 'Evaluating in Cityscapes style'
if logger is None:
msg = '\n' + msg
print_log(msg, logger=logger)
result_dir = self.suffix
eval_results = dict()
print_log(f'Evaluating results under {result_dir} ...', logger=logger)
print_log(
f'Evaluating results under {self.output_dir} ...', logger=logger)
CSEval.args.evalInstLevelScore = True
CSEval.args.predictionPath = osp.abspath(result_dir)
CSEval.args.predictionPath = osp.abspath(self.output_dir)
CSEval.args.evalPixelAccuracy = True
CSEval.args.JSONOutput = False
seg_map_list = []
pred_list = []
ann_dir = results[0]
# when evaluating with official cityscapesscripts,
# **_gtFine_labelIds.png is used
for seg_map in scandir(ann_dir, 'gtFine_labelIds.png', recursive=True):
seg_map_list.append(osp.join(ann_dir, seg_map))
pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
pred_list, gt_list = zip(*results)
metric = dict()
eval_results.update(
CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
CSEval.evaluateImgLists(pred_list, gt_list, CSEval.args))
metric['averageScoreCategories'] = eval_results[
'averageScoreCategories']
metric['averageScoreInstCategories'] = eval_results[
@ -133,7 +146,6 @@ class CitysMetric(BaseMetric):
"""Convert trainId to id for cityscapes."""
if isinstance(result, str):
result = np.load(result)
import cityscapesscripts.helpers.labels as CSLabels
result_copy = result.copy()
for trainId, label in CSLabels.trainId2label.items():
result_copy[result == trainId] = label.id

View File

@ -1,15 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmengine.structures import BaseDataElement, PixelData
from mmseg.evaluation import CitysMetric
from mmseg.evaluation import CityscapesMetric
from mmseg.structures import SegDataSample
class TestCitysMetric(TestCase):
class TestCityscapesMetric(TestCase):
def _demo_mm_inputs(self,
batch_size=1,
@ -42,9 +44,8 @@ class TestCitysMetric(TestCase):
gt_sem_seg_data = dict(data=gt_semantic_seg)
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
mm_inputs['data_sample'] = data_sample.to_dict()
mm_inputs['data_sample']['seg_map_path'] = \
'tests/data/pseudo_cityscapes_dataset/gtFine/val/\
frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png'
mm_inputs['data_sample'][
'seg_map_path'] = 'tests/data/pseudo_cityscapes_dataset/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png' # noqa
mm_inputs['seg_map_path'] = mm_inputs['data_sample'][
'seg_map_path']
@ -86,9 +87,8 @@ class TestCitysMetric(TestCase):
for pred in batch_datasampes:
if isinstance(pred, BaseDataElement):
test_data = pred.to_dict()
test_data['img_path'] = \
'tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/\
frankfurt/frankfurt_000000_000294_leftImg8bit.png'
test_data[
'img_path'] = 'tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png' # noqa
_predictions.append(test_data)
else:
@ -104,15 +104,23 @@ class TestCitysMetric(TestCase):
dict(**data, **result)
for data, result in zip(data_batch, predictions)
]
iou_metric = CitysMetric(citys_metrics=['cityscapes'])
iou_metric.process(data_batch, data_samples)
res = iou_metric.evaluate(6)
self.assertIsInstance(res, dict)
# test to_label_id = True
iou_metric = CitysMetric(
citys_metrics=['cityscapes'], to_label_id=True)
iou_metric.process(data_batch, data_samples)
res = iou_metric.evaluate(6)
# test keep_results should be True when format_only is True
with pytest.raises(AssertionError):
CityscapesMetric(
output_dir='tmp', format_only=True, keep_results=False)
# test evaluate with cityscape metric
metric = CityscapesMetric(output_dir='tmp')
metric.process(data_batch, data_samples)
res = metric.evaluate(2)
self.assertIsInstance(res, dict)
# test format_only
metric = CityscapesMetric(
output_dir='tmp', format_only=True, keep_results=True)
metric.process(data_batch, data_samples)
metric.evaluate(2)
assert osp.exists('tmp')
assert osp.isfile('tmp/frankfurt_000000_000294_leftImg8bit.png')
import shutil
shutil.rmtree('.format_cityscapes')
shutil.rmtree('tmp')