From 6ca42737b62d6708c7346fe9d10ad938cf13135e Mon Sep 17 00:00:00 2001 From: "linfangjian.vendor" Date: Tue, 14 Jun 2022 02:53:46 +0000 Subject: [PATCH] [Fix] Fix pred --- mmseg/metrics/iou_metric.py | 2 +- mmseg/models/segmentors/encoder_decoder.py | 3 +-- tests/test_metrics/test_iou_metric.py | 1 + 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mmseg/metrics/iou_metric.py b/mmseg/metrics/iou_metric.py index c7964f33f..a7c0524e6 100644 --- a/mmseg/metrics/iou_metric.py +++ b/mmseg/metrics/iou_metric.py @@ -63,7 +63,7 @@ class IoUMetric(BaseMetric): reduce_zero_label = self.dataset_meta['reduce_zero_label'] for data, pred in zip(data_batch, predictions): label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy() - pred_label = pred['pred_sem_seg']['data'][0] + pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy() self.results.append( self.intersect_and_union(pred_label, label, num_classes, self.ignore_index, label_map, diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 1da7f0716..1e61c21e3 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -256,8 +256,7 @@ class EncoderDecoder(BaseSegmentor): results_dict = dict() seg_logit = self.inference(batch_inputs, batch_img_metas, rescale) results_dict['seg_logits'] = seg_logit - seg_pred = seg_logit.argmax(dim=1) - seg_pred = seg_pred.cpu().numpy() + seg_pred = seg_logit.argmax(dim=1, keepdim=True) results_dict['pred_sem_seg'] = seg_pred results_list = self.postprocess_result(results_dict) return results_list diff --git a/tests/test_metrics/test_iou_metric.py b/tests/test_metrics/test_iou_metric.py index 58d238bfd..809fabe9a 100644 --- a/tests/test_metrics/test_iou_metric.py +++ b/tests/test_metrics/test_iou_metric.py @@ -65,6 +65,7 @@ class TestIoUMetric(TestCase): results_dict['seg_logits'] = seg_logit seg_pred = np.random.randint( 0, num_classes, (batch_size, h, w), dtype=np.uint8) + seg_pred = torch.LongTensor(seg_pred) results_dict['pred_sem_seg'] = seg_pred batch_datasampes = [