diff --git a/mmseg/metrics/iou_metric.py b/mmseg/metrics/iou_metric.py index a7c0524e6..f0dd45a4c 100644 --- a/mmseg/metrics/iou_metric.py +++ b/mmseg/metrics/iou_metric.py @@ -1,8 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from collections import OrderedDict -from typing import Dict, List, Optional, Sequence, Union +from typing import Dict, List, Optional, Sequence -import mmcv import numpy as np import torch from mmengine.evaluator import BaseMetric @@ -59,15 +58,12 @@ class IoUMetric(BaseMetric): predictions (Sequence[dict]): A batch of outputs from the model. """ num_classes = len(self.dataset_meta['classes']) - label_map = self.dataset_meta['label_map'] - reduce_zero_label = self.dataset_meta['reduce_zero_label'] for data, pred in zip(data_batch, predictions): - label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy() - pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy() + label = data['data_sample']['gt_sem_seg']['data'][0].cpu() + pred_label = pred['pred_sem_seg']['data'][0].cpu() self.results.append( self.intersect_and_union(pred_label, label, num_classes, - self.ignore_index, label_map, - reduce_zero_label)) + self.ignore_index)) def compute_metrics(self, results: list) -> Dict[str, float]: """Compute the metrics from processed results. @@ -129,25 +125,17 @@ class IoUMetric(BaseMetric): return metrics @staticmethod - def intersect_and_union(pred_label: Union[np.ndarray, str], - label: Union[np.ndarray, str], - num_classes: int, - ignore_index: int, - label_map: dict = dict(), - reduce_zero_label: bool = False): + def intersect_and_union(pred_label: torch.tensor, label: torch.tensor, + num_classes: int, ignore_index: int): """Calculate intersection and Union. Args: - pred_label (ndarray | str): Prediction segmentation map + pred_label (torch.tensor): Prediction segmentation map or predict result filename. The shape is (H, W). - label (ndarray | str): Ground truth segmentation map + label (torch.tensor): Ground truth segmentation map or label filename. The shape is (H, W). num_classes (int): Number of categories. ignore_index (int): Index that will be ignored in evaluation. - label_map (dict): Mapping old labels to new labels. The parameter - will work only when label is str. Default: dict(). - reduce_zero_label (bool): Whether ignore zero label. The parameter - will work only when label is str. Default: False. Returns: torch.Tensor: The intersection of prediction and ground truth @@ -158,25 +146,6 @@ class IoUMetric(BaseMetric): torch.Tensor: The ground truth histogram on all classes. """ - if isinstance(pred_label, str): - pred_label = torch.from_numpy(np.load(pred_label)) - else: - pred_label = torch.from_numpy((pred_label)) - - if isinstance(label, str): - label = torch.from_numpy( - mmcv.imread(label, flag='unchanged', backend='pillow')) - else: - label = torch.from_numpy(label) - - if label_map is not None: - for old_id, new_id in label_map.items(): - label[label == old_id] = new_id - if reduce_zero_label: - label[label == 0] = 255 - label = label - 1 - label[label == 254] = 255 - mask = (label != ignore_index) pred_label = pred_label[mask] label = label[mask]