mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix iou metric
This commit is contained in:
parent
ec52a27299
commit
e445836bd4
@ -1,8 +1,7 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Dict, List, Optional, Sequence, Union
|
from typing import Dict, List, Optional, Sequence
|
||||||
|
|
||||||
import mmcv
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from mmengine.evaluator import BaseMetric
|
from mmengine.evaluator import BaseMetric
|
||||||
@ -59,15 +58,12 @@ class IoUMetric(BaseMetric):
|
|||||||
predictions (Sequence[dict]): A batch of outputs from the model.
|
predictions (Sequence[dict]): A batch of outputs from the model.
|
||||||
"""
|
"""
|
||||||
num_classes = len(self.dataset_meta['classes'])
|
num_classes = len(self.dataset_meta['classes'])
|
||||||
label_map = self.dataset_meta['label_map']
|
|
||||||
reduce_zero_label = self.dataset_meta['reduce_zero_label']
|
|
||||||
for data, pred in zip(data_batch, predictions):
|
for data, pred in zip(data_batch, predictions):
|
||||||
label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy()
|
label = data['data_sample']['gt_sem_seg']['data'][0].cpu()
|
||||||
pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy()
|
pred_label = pred['pred_sem_seg']['data'][0].cpu()
|
||||||
self.results.append(
|
self.results.append(
|
||||||
self.intersect_and_union(pred_label, label, num_classes,
|
self.intersect_and_union(pred_label, label, num_classes,
|
||||||
self.ignore_index, label_map,
|
self.ignore_index))
|
||||||
reduce_zero_label))
|
|
||||||
|
|
||||||
def compute_metrics(self, results: list) -> Dict[str, float]:
|
def compute_metrics(self, results: list) -> Dict[str, float]:
|
||||||
"""Compute the metrics from processed results.
|
"""Compute the metrics from processed results.
|
||||||
@ -129,25 +125,17 @@ class IoUMetric(BaseMetric):
|
|||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def intersect_and_union(pred_label: Union[np.ndarray, str],
|
def intersect_and_union(pred_label: torch.tensor, label: torch.tensor,
|
||||||
label: Union[np.ndarray, str],
|
num_classes: int, ignore_index: int):
|
||||||
num_classes: int,
|
|
||||||
ignore_index: int,
|
|
||||||
label_map: dict = dict(),
|
|
||||||
reduce_zero_label: bool = False):
|
|
||||||
"""Calculate intersection and Union.
|
"""Calculate intersection and Union.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pred_label (ndarray | str): Prediction segmentation map
|
pred_label (torch.tensor): Prediction segmentation map
|
||||||
or predict result filename. The shape is (H, W).
|
or predict result filename. The shape is (H, W).
|
||||||
label (ndarray | str): Ground truth segmentation map
|
label (torch.tensor): Ground truth segmentation map
|
||||||
or label filename. The shape is (H, W).
|
or label filename. The shape is (H, W).
|
||||||
num_classes (int): Number of categories.
|
num_classes (int): Number of categories.
|
||||||
ignore_index (int): Index that will be ignored in evaluation.
|
ignore_index (int): Index that will be ignored in evaluation.
|
||||||
label_map (dict): Mapping old labels to new labels. The parameter
|
|
||||||
will work only when label is str. Default: dict().
|
|
||||||
reduce_zero_label (bool): Whether ignore zero label. The parameter
|
|
||||||
will work only when label is str. Default: False.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: The intersection of prediction and ground truth
|
torch.Tensor: The intersection of prediction and ground truth
|
||||||
@ -158,25 +146,6 @@ class IoUMetric(BaseMetric):
|
|||||||
torch.Tensor: The ground truth histogram on all classes.
|
torch.Tensor: The ground truth histogram on all classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(pred_label, str):
|
|
||||||
pred_label = torch.from_numpy(np.load(pred_label))
|
|
||||||
else:
|
|
||||||
pred_label = torch.from_numpy((pred_label))
|
|
||||||
|
|
||||||
if isinstance(label, str):
|
|
||||||
label = torch.from_numpy(
|
|
||||||
mmcv.imread(label, flag='unchanged', backend='pillow'))
|
|
||||||
else:
|
|
||||||
label = torch.from_numpy(label)
|
|
||||||
|
|
||||||
if label_map is not None:
|
|
||||||
for old_id, new_id in label_map.items():
|
|
||||||
label[label == old_id] = new_id
|
|
||||||
if reduce_zero_label:
|
|
||||||
label[label == 0] = 255
|
|
||||||
label = label - 1
|
|
||||||
label[label == 254] = 255
|
|
||||||
|
|
||||||
mask = (label != ignore_index)
|
mask = (label != ignore_index)
|
||||||
pred_label = pred_label[mask]
|
pred_label = pred_label[mask]
|
||||||
label = label[mask]
|
label = label[mask]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user