[Fix] Support format_result and fix prefix param in cityscape metric, and rename CitysMetric to CityscapesMetric (#2660)
as titlepull/2715/head
parent
6c3599bd9d
commit
a8aafdd902
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .metrics import CitysMetric, IoUMetric
|
||||
from .metrics import CityscapesMetric, IoUMetric
|
||||
|
||||
__all__ = ['IoUMetric', 'CitysMetric']
|
||||
__all__ = ['IoUMetric', 'CityscapesMetric']
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .citys_metric import CitysMetric
|
||||
from .citys_metric import CityscapesMetric
|
||||
from .iou_metric import IoUMetric
|
||||
|
||||
__all__ = ['IoUMetric', 'CitysMetric']
|
||||
__all__ = ['IoUMetric', 'CityscapesMetric']
|
||||
|
|
|
@ -1,30 +1,41 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
try:
|
||||
|
||||
import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
|
||||
import cityscapesscripts.helpers.labels as CSLabels
|
||||
except ImportError:
|
||||
CSLabels = None
|
||||
CSEval = None
|
||||
|
||||
import numpy as np
|
||||
from mmengine.dist import is_main_process, master_only
|
||||
from mmengine.evaluator import BaseMetric
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from mmengine.utils import mkdir_or_exist, scandir
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.registry import METRICS
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class CitysMetric(BaseMetric):
|
||||
class CityscapesMetric(BaseMetric):
|
||||
"""Cityscapes evaluation metric.
|
||||
|
||||
Args:
|
||||
output_dir (str): The directory for output prediction
|
||||
ignore_index (int): Index that will be ignored in evaluation.
|
||||
Default: 255.
|
||||
citys_metrics (list[str] | str): Metrics to be evaluated,
|
||||
Default: ['cityscapes'].
|
||||
to_label_id (bool): whether convert output to label_id for
|
||||
submission. Default: True.
|
||||
suffix (str): The filename prefix of the png files.
|
||||
If the prefix is "somepath/xxx", the png files will be
|
||||
named "somepath/xxx.png". Default: '.format_cityscapes'.
|
||||
format_only (bool): Only format result for results commit without
|
||||
perform evaluation. It is useful when you want to format the result
|
||||
to a specific format and submit it to the test server.
|
||||
Defaults to False.
|
||||
keep_results (bool): Whether to keep the results. When ``format_only``
|
||||
is True, ``keep_results`` must be True. Defaults to False.
|
||||
collect_device (str): Device name used for collecting results from
|
||||
different ranks during distributed training. Must be 'cpu' or
|
||||
'gpu'. Defaults to 'cpu'.
|
||||
|
@ -35,19 +46,34 @@ class CitysMetric(BaseMetric):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
output_dir: str,
|
||||
ignore_index: int = 255,
|
||||
citys_metrics: List[str] = ['cityscapes'],
|
||||
to_label_id: bool = True,
|
||||
suffix: str = '.format_cityscapes',
|
||||
format_only: bool = False,
|
||||
keep_results: bool = False,
|
||||
collect_device: str = 'cpu',
|
||||
prefix: Optional[str] = None) -> None:
|
||||
super().__init__(collect_device=collect_device, prefix=prefix)
|
||||
|
||||
if CSEval is None:
|
||||
raise ImportError('Please run "pip install cityscapesscripts" to '
|
||||
'install cityscapesscripts first.')
|
||||
self.output_dir = output_dir
|
||||
self.ignore_index = ignore_index
|
||||
self.metrics = citys_metrics
|
||||
assert self.metrics[0] == 'cityscapes'
|
||||
self.to_label_id = to_label_id
|
||||
self.suffix = suffix
|
||||
|
||||
self.format_only = format_only
|
||||
if format_only:
|
||||
assert keep_results, (
|
||||
'When format_only is True, the results must be keep, please '
|
||||
f'set keep_results as True, but got {keep_results}')
|
||||
self.keep_results = keep_results
|
||||
self.prefix = prefix
|
||||
if is_main_process():
|
||||
mkdir_or_exist(self.output_dir)
|
||||
|
||||
@master_only
|
||||
def __del__(self) -> None:
|
||||
"""Clean up."""
|
||||
if not self.keep_results:
|
||||
shutil.rmtree(self.output_dir)
|
||||
|
||||
def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
|
||||
"""Process one batch of data and data_samples.
|
||||
|
@ -59,26 +85,23 @@ class CitysMetric(BaseMetric):
|
|||
data_batch (dict): A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
mkdir_or_exist(self.suffix)
|
||||
mkdir_or_exist(self.output_dir)
|
||||
|
||||
for data_sample in data_samples:
|
||||
pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy()
|
||||
# results2img
|
||||
if self.to_label_id:
|
||||
pred_label = self._convert_to_label_id(pred_label)
|
||||
# when evaluating with official cityscapesscripts,
|
||||
# labelIds should be used
|
||||
pred_label = self._convert_to_label_id(pred_label)
|
||||
basename = osp.splitext(osp.basename(data_sample['img_path']))[0]
|
||||
png_filename = osp.join(self.suffix, f'{basename}.png')
|
||||
png_filename = osp.abspath(
|
||||
osp.join(self.output_dir, f'{basename}.png'))
|
||||
output = Image.fromarray(pred_label.astype(np.uint8)).convert('P')
|
||||
import cityscapesscripts.helpers.labels as CSLabels
|
||||
palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
|
||||
for label_id, label in CSLabels.id2label.items():
|
||||
palette[label_id] = label.color
|
||||
output.putpalette(palette)
|
||||
output.save(png_filename)
|
||||
|
||||
ann_dir = osp.join(data_samples[0]['seg_map_path'].split('val')[0],
|
||||
'val')
|
||||
self.results.append(ann_dir)
|
||||
# when evaluating with official cityscapesscripts,
|
||||
# **_gtFine_labelIds.png is used
|
||||
gt_filename = data_sample['seg_map_path'].replace(
|
||||
'labelTrainIds.png', 'labelIds.png')
|
||||
self.results.append((png_filename, gt_filename))
|
||||
|
||||
def compute_metrics(self, results: list) -> Dict[str, float]:
|
||||
"""Compute the metrics from processed results.
|
||||
|
@ -90,38 +113,28 @@ class CitysMetric(BaseMetric):
|
|||
dict[str: float]: Cityscapes evaluation results.
|
||||
"""
|
||||
logger: MMLogger = MMLogger.get_current_instance()
|
||||
try:
|
||||
import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
|
||||
except ImportError:
|
||||
raise ImportError('Please run "pip install cityscapesscripts" to '
|
||||
'install cityscapesscripts first.')
|
||||
msg = 'Evaluating in Cityscapes style'
|
||||
if self.format_only:
|
||||
logger.info(f'results are saved to {osp.dirname(self.output_dir)}')
|
||||
return OrderedDict()
|
||||
|
||||
msg = 'Evaluating in Cityscapes style'
|
||||
if logger is None:
|
||||
msg = '\n' + msg
|
||||
print_log(msg, logger=logger)
|
||||
|
||||
result_dir = self.suffix
|
||||
|
||||
eval_results = dict()
|
||||
print_log(f'Evaluating results under {result_dir} ...', logger=logger)
|
||||
print_log(
|
||||
f'Evaluating results under {self.output_dir} ...', logger=logger)
|
||||
|
||||
CSEval.args.evalInstLevelScore = True
|
||||
CSEval.args.predictionPath = osp.abspath(result_dir)
|
||||
CSEval.args.predictionPath = osp.abspath(self.output_dir)
|
||||
CSEval.args.evalPixelAccuracy = True
|
||||
CSEval.args.JSONOutput = False
|
||||
|
||||
seg_map_list = []
|
||||
pred_list = []
|
||||
ann_dir = results[0]
|
||||
# when evaluating with official cityscapesscripts,
|
||||
# **_gtFine_labelIds.png is used
|
||||
for seg_map in scandir(ann_dir, 'gtFine_labelIds.png', recursive=True):
|
||||
seg_map_list.append(osp.join(ann_dir, seg_map))
|
||||
pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
|
||||
pred_list, gt_list = zip(*results)
|
||||
metric = dict()
|
||||
eval_results.update(
|
||||
CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
|
||||
CSEval.evaluateImgLists(pred_list, gt_list, CSEval.args))
|
||||
metric['averageScoreCategories'] = eval_results[
|
||||
'averageScoreCategories']
|
||||
metric['averageScoreInstCategories'] = eval_results[
|
||||
|
@ -133,7 +146,6 @@ class CitysMetric(BaseMetric):
|
|||
"""Convert trainId to id for cityscapes."""
|
||||
if isinstance(result, str):
|
||||
result = np.load(result)
|
||||
import cityscapesscripts.helpers.labels as CSLabels
|
||||
result_copy = result.copy()
|
||||
for trainId, label in CSLabels.trainId2label.items():
|
||||
result_copy[result == trainId] = label.id
|
||||
|
|
|
@ -1,15 +1,17 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.structures import BaseDataElement, PixelData
|
||||
|
||||
from mmseg.evaluation import CitysMetric
|
||||
from mmseg.evaluation import CityscapesMetric
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
|
||||
class TestCitysMetric(TestCase):
|
||||
class TestCityscapesMetric(TestCase):
|
||||
|
||||
def _demo_mm_inputs(self,
|
||||
batch_size=1,
|
||||
|
@ -42,9 +44,8 @@ class TestCitysMetric(TestCase):
|
|||
gt_sem_seg_data = dict(data=gt_semantic_seg)
|
||||
data_sample.gt_sem_seg = PixelData(**gt_sem_seg_data)
|
||||
mm_inputs['data_sample'] = data_sample.to_dict()
|
||||
mm_inputs['data_sample']['seg_map_path'] = \
|
||||
'tests/data/pseudo_cityscapes_dataset/gtFine/val/\
|
||||
frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png'
|
||||
mm_inputs['data_sample'][
|
||||
'seg_map_path'] = 'tests/data/pseudo_cityscapes_dataset/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png' # noqa
|
||||
|
||||
mm_inputs['seg_map_path'] = mm_inputs['data_sample'][
|
||||
'seg_map_path']
|
||||
|
@ -86,9 +87,8 @@ class TestCitysMetric(TestCase):
|
|||
for pred in batch_datasampes:
|
||||
if isinstance(pred, BaseDataElement):
|
||||
test_data = pred.to_dict()
|
||||
test_data['img_path'] = \
|
||||
'tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/\
|
||||
frankfurt/frankfurt_000000_000294_leftImg8bit.png'
|
||||
test_data[
|
||||
'img_path'] = 'tests/data/pseudo_cityscapes_dataset/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png' # noqa
|
||||
|
||||
_predictions.append(test_data)
|
||||
else:
|
||||
|
@ -104,15 +104,23 @@ class TestCitysMetric(TestCase):
|
|||
dict(**data, **result)
|
||||
for data, result in zip(data_batch, predictions)
|
||||
]
|
||||
iou_metric = CitysMetric(citys_metrics=['cityscapes'])
|
||||
iou_metric.process(data_batch, data_samples)
|
||||
res = iou_metric.evaluate(6)
|
||||
self.assertIsInstance(res, dict)
|
||||
# test to_label_id = True
|
||||
iou_metric = CitysMetric(
|
||||
citys_metrics=['cityscapes'], to_label_id=True)
|
||||
iou_metric.process(data_batch, data_samples)
|
||||
res = iou_metric.evaluate(6)
|
||||
# test keep_results should be True when format_only is True
|
||||
with pytest.raises(AssertionError):
|
||||
CityscapesMetric(
|
||||
output_dir='tmp', format_only=True, keep_results=False)
|
||||
|
||||
# test evaluate with cityscape metric
|
||||
metric = CityscapesMetric(output_dir='tmp')
|
||||
metric.process(data_batch, data_samples)
|
||||
res = metric.evaluate(2)
|
||||
self.assertIsInstance(res, dict)
|
||||
|
||||
# test format_only
|
||||
metric = CityscapesMetric(
|
||||
output_dir='tmp', format_only=True, keep_results=True)
|
||||
metric.process(data_batch, data_samples)
|
||||
metric.evaluate(2)
|
||||
assert osp.exists('tmp')
|
||||
assert osp.isfile('tmp/frankfurt_000000_000294_leftImg8bit.png')
|
||||
import shutil
|
||||
shutil.rmtree('.format_cityscapes')
|
||||
shutil.rmtree('tmp')
|
||||
|
|
Loading…
Reference in New Issue