[Fix] Format shape check (#2753)

as title
This commit is contained in:
Miao Zheng 2023-03-15 17:49:59 +08:00 committed by GitHub
parent dd47cef801
commit 3cc7ae2167
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 3 deletions

View File

@ -74,12 +74,12 @@ class PackSegInputs(BaseTransform):
data_sample = SegDataSample() data_sample = SegDataSample()
if 'gt_seg_map' in results: if 'gt_seg_map' in results:
if results['gt_seg_map'].shape == 2: if len(results['gt_seg_map'].shape) == 2:
data = to_tensor(results['gt_seg_map'][None, data = to_tensor(results['gt_seg_map'][None,
...].astype(np.int64)) ...].astype(np.int64))
else: else:
warnings.warn('Please pay attention your ground truth ' warnings.warn('Please pay attention your ground truth '
'segmentation map, usually the segentation ' 'segmentation map, usually the segmentation '
'map is 2D, but got ' 'map is 2D, but got '
f'{results["gt_seg_map"].shape}') f'{results["gt_seg_map"].shape}')
data = to_tensor(results['gt_seg_map'].astype(np.int64)) data = to_tensor(results['gt_seg_map'].astype(np.int64))

View File

@ -4,6 +4,7 @@ import os.path as osp
import unittest import unittest
import numpy as np import numpy as np
import pytest
from mmengine.structures import BaseDataElement from mmengine.structures import BaseDataElement
from mmseg.datasets.transforms import PackSegInputs from mmseg.datasets.transforms import PackSegInputs
@ -46,7 +47,10 @@ class TestPackSegInputs(unittest.TestCase):
self.assertEqual(results['data_samples'].ori_shape, self.assertEqual(results['data_samples'].ori_shape,
results['data_samples'].gt_sem_seg.shape) results['data_samples'].gt_sem_seg.shape)
results = copy.deepcopy(self.results) results = copy.deepcopy(self.results)
# test dataset shape is not 2D
results['gt_seg_map'] = np.random.rand(3, 300, 400) results['gt_seg_map'] = np.random.rand(3, 300, 400)
msg = 'the segmentation map is 2D'
with pytest.warns(UserWarning, match=msg):
results = transform(results) results = transform(results)
self.assertEqual(results['data_samples'].ori_shape, self.assertEqual(results['data_samples'].ori_shape,
results['data_samples'].gt_sem_seg.shape) results['data_samples'].gt_sem_seg.shape)