diff --git a/mmseg/datasets/transforms/formatting.py b/mmseg/datasets/transforms/formatting.py index 57cda9b10..46171d64f 100644 --- a/mmseg/datasets/transforms/formatting.py +++ b/mmseg/datasets/transforms/formatting.py @@ -74,12 +74,12 @@ class PackSegInputs(BaseTransform): data_sample = SegDataSample() if 'gt_seg_map' in results: - if results['gt_seg_map'].shape == 2: + if len(results['gt_seg_map'].shape) == 2: data = to_tensor(results['gt_seg_map'][None, ...].astype(np.int64)) else: warnings.warn('Please pay attention your ground truth ' - 'segmentation map, usually the segentation ' + 'segmentation map, usually the segmentation ' 'map is 2D, but got ' f'{results["gt_seg_map"].shape}') data = to_tensor(results['gt_seg_map'].astype(np.int64)) diff --git a/tests/test_datasets/test_formatting.py b/tests/test_datasets/test_formatting.py index 51fd90d04..d0e5820ec 100644 --- a/tests/test_datasets/test_formatting.py +++ b/tests/test_datasets/test_formatting.py @@ -4,6 +4,7 @@ import os.path as osp import unittest import numpy as np +import pytest from mmengine.structures import BaseDataElement from mmseg.datasets.transforms import PackSegInputs @@ -46,8 +47,11 @@ class TestPackSegInputs(unittest.TestCase): self.assertEqual(results['data_samples'].ori_shape, results['data_samples'].gt_sem_seg.shape) results = copy.deepcopy(self.results) + # test dataset shape is not 2D results['gt_seg_map'] = np.random.rand(3, 300, 400) - results = transform(results) + msg = 'the segmentation map is 2D' + with pytest.warns(UserWarning, match=msg): + results = transform(results) self.assertEqual(results['data_samples'].ori_shape, results['data_samples'].gt_sem_seg.shape)