[Fix] Fix iou metric

This commit is contained in:
linfangjian.vendor 2022-06-15 08:18:26 +00:00 committed by zhengmiao
parent ec52a27299
commit e445836bd4

View File

@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict from collections import OrderedDict
from typing import Dict, List, Optional, Sequence, Union from typing import Dict, List, Optional, Sequence
import mmcv
import numpy as np import numpy as np
import torch import torch
from mmengine.evaluator import BaseMetric from mmengine.evaluator import BaseMetric
@ -59,15 +58,12 @@ class IoUMetric(BaseMetric):
predictions (Sequence[dict]): A batch of outputs from the model. predictions (Sequence[dict]): A batch of outputs from the model.
""" """
num_classes = len(self.dataset_meta['classes']) num_classes = len(self.dataset_meta['classes'])
label_map = self.dataset_meta['label_map']
reduce_zero_label = self.dataset_meta['reduce_zero_label']
for data, pred in zip(data_batch, predictions): for data, pred in zip(data_batch, predictions):
label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy() label = data['data_sample']['gt_sem_seg']['data'][0].cpu()
pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy() pred_label = pred['pred_sem_seg']['data'][0].cpu()
self.results.append( self.results.append(
self.intersect_and_union(pred_label, label, num_classes, self.intersect_and_union(pred_label, label, num_classes,
self.ignore_index, label_map, self.ignore_index))
reduce_zero_label))
def compute_metrics(self, results: list) -> Dict[str, float]: def compute_metrics(self, results: list) -> Dict[str, float]:
"""Compute the metrics from processed results. """Compute the metrics from processed results.
@ -129,25 +125,17 @@ class IoUMetric(BaseMetric):
return metrics return metrics
@staticmethod @staticmethod
def intersect_and_union(pred_label: Union[np.ndarray, str], def intersect_and_union(pred_label: torch.tensor, label: torch.tensor,
label: Union[np.ndarray, str], num_classes: int, ignore_index: int):
num_classes: int,
ignore_index: int,
label_map: dict = dict(),
reduce_zero_label: bool = False):
"""Calculate intersection and Union. """Calculate intersection and Union.
Args: Args:
pred_label (ndarray | str): Prediction segmentation map pred_label (torch.tensor): Prediction segmentation map
or predict result filename. The shape is (H, W). or predict result filename. The shape is (H, W).
label (ndarray | str): Ground truth segmentation map label (torch.tensor): Ground truth segmentation map
or label filename. The shape is (H, W). or label filename. The shape is (H, W).
num_classes (int): Number of categories. num_classes (int): Number of categories.
ignore_index (int): Index that will be ignored in evaluation. ignore_index (int): Index that will be ignored in evaluation.
label_map (dict): Mapping old labels to new labels. The parameter
will work only when label is str. Default: dict().
reduce_zero_label (bool): Whether ignore zero label. The parameter
will work only when label is str. Default: False.
Returns: Returns:
torch.Tensor: The intersection of prediction and ground truth torch.Tensor: The intersection of prediction and ground truth
@ -158,25 +146,6 @@ class IoUMetric(BaseMetric):
torch.Tensor: The ground truth histogram on all classes. torch.Tensor: The ground truth histogram on all classes.
""" """
if isinstance(pred_label, str):
pred_label = torch.from_numpy(np.load(pred_label))
else:
pred_label = torch.from_numpy((pred_label))
if isinstance(label, str):
label = torch.from_numpy(
mmcv.imread(label, flag='unchanged', backend='pillow'))
else:
label = torch.from_numpy(label)
if label_map is not None:
for old_id, new_id in label_map.items():
label[label == old_id] = new_id
if reduce_zero_label:
label[label == 0] = 255
label = label - 1
label[label == 254] = 255
mask = (label != ignore_index) mask = (label != ignore_index)
pred_label = pred_label[mask] pred_label = pred_label[mask]
label = label[mask] label = label[mask]