diff --git a/mmseg/metrics/iou_metric.py b/mmseg/metrics/iou_metric.py index c7964f33f..a7c0524e6 100644 --- a/mmseg/metrics/iou_metric.py +++ b/mmseg/metrics/iou_metric.py @@ -63,7 +63,7 @@ class IoUMetric(BaseMetric): reduce_zero_label = self.dataset_meta['reduce_zero_label'] for data, pred in zip(data_batch, predictions): label = data['data_sample']['gt_sem_seg']['data'][0].cpu().numpy() - pred_label = pred['pred_sem_seg']['data'][0] + pred_label = pred['pred_sem_seg']['data'][0].cpu().numpy() self.results.append( self.intersect_and_union(pred_label, label, num_classes, self.ignore_index, label_map, diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 1da7f0716..1e61c21e3 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -256,8 +256,7 @@ class EncoderDecoder(BaseSegmentor): results_dict = dict() seg_logit = self.inference(batch_inputs, batch_img_metas, rescale) results_dict['seg_logits'] = seg_logit - seg_pred = seg_logit.argmax(dim=1) - seg_pred = seg_pred.cpu().numpy() + seg_pred = seg_logit.argmax(dim=1, keepdim=True) results_dict['pred_sem_seg'] = seg_pred results_list = self.postprocess_result(results_dict) return results_list diff --git a/tests/test_metrics/test_iou_metric.py b/tests/test_metrics/test_iou_metric.py index 58d238bfd..809fabe9a 100644 --- a/tests/test_metrics/test_iou_metric.py +++ b/tests/test_metrics/test_iou_metric.py @@ -65,6 +65,7 @@ class TestIoUMetric(TestCase): results_dict['seg_logits'] = seg_logit seg_pred = np.random.randint( 0, num_classes, (batch_size, h, w), dtype=np.uint8) + seg_pred = torch.LongTensor(seg_pred) results_dict['pred_sem_seg'] = seg_pred batch_datasampes = [