[Fix] Fix iou metric

This commit is contained in:
linfangjian.vendor 2022-06-15 08:18:26 +00:00 committed by zhengmiao
parent ec52a27299
commit e445836bd4

View File

@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Dict, List, Optional, Sequence, Union
from typing import Dict, List, Optional, Sequence
import mmcv
import numpy as np
import torch
from mmengine.evaluator import BaseMetric
@ -59,15 +58,12 @@ class IoUMetric(BaseMetric):
predictions (Sequence[dict]): A batch of outputs from the model.
"""
num_classes = len(self.dataset_meta['classes'])
label_map = self.dataset_meta['label_map']
reduce_zero_label = self.dataset_meta['reduce_zero_label']
for data, pred in zip(data_batch, predictions):
label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy()
pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy()
label = data['data_sample']['gt_sem_seg']['data'][0].cpu()
pred_label = pred['pred_sem_seg']['data'][0].cpu()
self.results.append(
self.intersect_and_union(pred_label, label, num_classes,
self.ignore_index, label_map,
reduce_zero_label))
self.ignore_index))
def compute_metrics(self, results: list) -> Dict[str, float]:
"""Compute the metrics from processed results.
@ -129,25 +125,17 @@ class IoUMetric(BaseMetric):
return metrics
@staticmethod
def intersect_and_union(pred_label: Union[np.ndarray, str],
label: Union[np.ndarray, str],
num_classes: int,
ignore_index: int,
label_map: dict = dict(),
reduce_zero_label: bool = False):
def intersect_and_union(pred_label: torch.tensor, label: torch.tensor,
num_classes: int, ignore_index: int):
"""Calculate intersection and Union.
Args:
pred_label (ndarray | str): Prediction segmentation map
pred_label (torch.tensor): Prediction segmentation map
or predict result filename. The shape is (H, W).
label (ndarray | str): Ground truth segmentation map
label (torch.tensor): Ground truth segmentation map
or label filename. The shape is (H, W).
num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation.
label_map (dict): Mapping old labels to new labels. The parameter
will work only when label is str. Default: dict().
reduce_zero_label (bool): Whether ignore zero label. The parameter
will work only when label is str. Default: False.
Returns:
torch.Tensor: The intersection of prediction and ground truth
@ -158,25 +146,6 @@ class IoUMetric(BaseMetric):
torch.Tensor: The ground truth histogram on all classes.
"""
if isinstance(pred_label, str):
pred_label = torch.from_numpy(np.load(pred_label))
else:
pred_label = torch.from_numpy((pred_label))
if isinstance(label, str):
label = torch.from_numpy(
mmcv.imread(label, flag='unchanged', backend='pillow'))
else:
label = torch.from_numpy(label)
if label_map is not None:
for old_id, new_id in label_map.items():
label[label == old_id] = new_id
if reduce_zero_label:
label[label == 0] = 255
label = label - 1
label[label == 254] = 255
mask = (label != ignore_index)
pred_label = pred_label[mask]
label = label[mask]