# mmsegmentation/tests/test_evaluation/test_metrics/test_iou_metric.py
#
# 77 lines
# 2.6 KiB
# Python
# Raw Normal View History
#
# 2022-06-02 22:15:28 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
from mmengine.structures import PixelData
# 2022-06-02 22:15:28 +08:00
from mmseg.evaluation import IoUMetric
# 2022-08-03 15:43:23 +08:00
from mmseg.structures import SegDataSample
# 2022-06-02 22:15:28 +08:00
class TestIoUMetric(TestCase):
    """Smoke test that runs ``IoUMetric`` on randomly generated samples."""

    def _demo_mm_inputs(self,
                        batch_size=2,
                        image_shapes=(3, 64, 64),
                        num_classes=5):
        """Create ground-truth data samples for a demo batch.

        Args:
            batch_size (int): Number of samples to create. Defaults to 2.
            image_shapes (tuple | List[tuple]): Either one ``(C, H, W)``
                shape shared by every sample, or a list with exactly
                ``batch_size`` shapes. Defaults to (3, 64, 64).
            num_classes (int): Labels are drawn from ``[0, num_classes)``.
                Defaults to 5.

        Returns:
            list: One entry per sample, each being the ``to_dict()`` form
            of a ``SegDataSample`` whose ``gt_sem_seg`` wraps a random
            label map of shape ``(1, H, W)``.
        """
        if isinstance(image_shapes, list):
            assert len(image_shapes) == batch_size
        else:
            # Broadcast the single shape to every sample in the batch.
            image_shapes = [image_shapes] * batch_size

        data_samples = []
        for _, h, w in image_shapes:
            # Labels are generated as uint8, so num_classes must not
            # exceed 256.
            gt_semantic_seg = np.random.randint(
                0, num_classes, (1, h, w), dtype=np.uint8)
            gt_semantic_seg = torch.LongTensor(gt_semantic_seg)

            data_sample = SegDataSample()
            data_sample.gt_sem_seg = PixelData(data=gt_semantic_seg)
            data_samples.append(data_sample.to_dict())

        return data_samples

    def _demo_mm_model_output(self,
                              data_samples,
                              batch_size=2,
                              image_shapes=(3, 64, 64),
                              num_classes=5):
        """Attach random model predictions to every data sample in place.

        Args:
            data_samples (Iterable): Samples that support item assignment,
                e.g. the list returned by ``_demo_mm_inputs``.
            batch_size (int): Not used by this helper; kept so existing
                callers passing it keep working. Defaults to 2.
            image_shapes (tuple | List[tuple]): Either one ``(C, H, W)``
                shape shared by every sample, or a list with one shape per
                sample (same convention as ``_demo_mm_inputs``).
                Defaults to (3, 64, 64).
            num_classes (int): Number of logit channels; predicted labels
                are drawn from ``[0, num_classes)``. Defaults to 5.

        Returns:
            The ``data_samples`` object that was passed in. Each entry
            gains ``seg_logits`` of shape ``(num_classes, H, W)`` and
            ``pred_sem_seg`` of shape ``(1, H, W)``.
        """
        # Accept a per-sample list of shapes as well as a single shape, so
        # this helper agrees with what ``_demo_mm_inputs`` accepts.
        shapes_per_sample = isinstance(image_shapes, list)

        for idx, data_sample in enumerate(data_samples):
            _, h, w = image_shapes[idx] if shapes_per_sample else image_shapes
            data_sample['seg_logits'] = dict(
                data=torch.randn(num_classes, h, w))
            data_sample['pred_sem_seg'] = dict(
                data=torch.randint(0, num_classes, (1, h, w)))
        return data_samples

    def test_evaluate(self):
        """Test using the metric in the same way as Evaluator."""
        data_samples = self._demo_mm_inputs()
        data_samples = self._demo_mm_model_output(data_samples)

        iou_metric = IoUMetric(iou_metrics=['mIoU'])
        # Five class names, matching the helpers' default num_classes=5.
        iou_metric.dataset_meta = dict(
            classes=['wall', 'building', 'sky', 'floor', 'tree'],
            label_map=dict(),
            reduce_zero_label=False)

        # The first argument is a placeholder data batch of matching length.
        # NOTE(review): presumably IoUMetric.process only reads data_samples
        # -- confirm against the metric implementation.
        iou_metric.process([0] * len(data_samples), data_samples)

        # Pass the number of samples actually processed instead of a
        # hard-coded 6 that did not match the two generated samples.
        # NOTE(review): `size` is understood to be the dataset length used
        # when collecting results -- confirm against mmengine's BaseMetric.
        res = iou_metric.evaluate(len(data_samples))
        self.assertIsInstance(res, dict)