mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix iou metric
This commit is contained in:
parent
ec52a27299
commit
e445836bd4
@ -1,8 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.evaluator import BaseMetric
|
||||
@ -59,15 +58,12 @@ class IoUMetric(BaseMetric):
|
||||
predictions (Sequence[dict]): A batch of outputs from the model.
|
||||
"""
|
||||
num_classes = len(self.dataset_meta['classes'])
|
||||
label_map = self.dataset_meta['label_map']
|
||||
reduce_zero_label = self.dataset_meta['reduce_zero_label']
|
||||
for data, pred in zip(data_batch, predictions):
|
||||
label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy()
|
||||
pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy()
|
||||
label = data['data_sample']['gt_sem_seg']['data'][0].cpu()
|
||||
pred_label = pred['pred_sem_seg']['data'][0].cpu()
|
||||
self.results.append(
|
||||
self.intersect_and_union(pred_label, label, num_classes,
|
||||
self.ignore_index, label_map,
|
||||
reduce_zero_label))
|
||||
self.ignore_index))
|
||||
|
||||
def compute_metrics(self, results: list) -> Dict[str, float]:
|
||||
"""Compute the metrics from processed results.
|
||||
@ -129,25 +125,17 @@ class IoUMetric(BaseMetric):
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def intersect_and_union(pred_label: Union[np.ndarray, str],
|
||||
label: Union[np.ndarray, str],
|
||||
num_classes: int,
|
||||
ignore_index: int,
|
||||
label_map: dict = dict(),
|
||||
reduce_zero_label: bool = False):
|
||||
def intersect_and_union(pred_label: torch.tensor, label: torch.tensor,
|
||||
num_classes: int, ignore_index: int):
|
||||
"""Calculate intersection and Union.
|
||||
|
||||
Args:
|
||||
pred_label (ndarray | str): Prediction segmentation map
|
||||
pred_label (torch.tensor): Prediction segmentation map
|
||||
or predict result filename. The shape is (H, W).
|
||||
label (ndarray | str): Ground truth segmentation map
|
||||
label (torch.tensor): Ground truth segmentation map
|
||||
or label filename. The shape is (H, W).
|
||||
num_classes (int): Number of categories.
|
||||
ignore_index (int): Index that will be ignored in evaluation.
|
||||
label_map (dict): Mapping old labels to new labels. The parameter
|
||||
will work only when label is str. Default: dict().
|
||||
reduce_zero_label (bool): Whether ignore zero label. The parameter
|
||||
will work only when label is str. Default: False.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The intersection of prediction and ground truth
|
||||
@ -158,25 +146,6 @@ class IoUMetric(BaseMetric):
|
||||
torch.Tensor: The ground truth histogram on all classes.
|
||||
"""
|
||||
|
||||
if isinstance(pred_label, str):
|
||||
pred_label = torch.from_numpy(np.load(pred_label))
|
||||
else:
|
||||
pred_label = torch.from_numpy((pred_label))
|
||||
|
||||
if isinstance(label, str):
|
||||
label = torch.from_numpy(
|
||||
mmcv.imread(label, flag='unchanged', backend='pillow'))
|
||||
else:
|
||||
label = torch.from_numpy(label)
|
||||
|
||||
if label_map is not None:
|
||||
for old_id, new_id in label_map.items():
|
||||
label[label == old_id] = new_id
|
||||
if reduce_zero_label:
|
||||
label[label == 0] = 255
|
||||
label = label - 1
|
||||
label[label == 254] = 255
|
||||
|
||||
mask = (label != ignore_index)
|
||||
pred_label = pred_label[mask]
|
||||
label = label[mask]
|
||||
|
Loading…
x
Reference in New Issue
Block a user