[Fix] Format shape check (#2753)

as title
This commit is contained in:
Miao Zheng 2023-03-15 17:49:59 +08:00 committed by GitHub
parent dd47cef801
commit 3cc7ae2167
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 3 deletions

View File

@ -74,12 +74,12 @@ class PackSegInputs(BaseTransform):
data_sample = SegDataSample()
if 'gt_seg_map' in results:
if results['gt_seg_map'].shape == 2:
if len(results['gt_seg_map'].shape) == 2:
data = to_tensor(results['gt_seg_map'][None,
...].astype(np.int64))
else:
warnings.warn('Please pay attention your ground truth '
'segmentation map, usually the segentation '
'segmentation map, usually the segmentation '
'map is 2D, but got '
f'{results["gt_seg_map"].shape}')
data = to_tensor(results['gt_seg_map'].astype(np.int64))

View File

@ -4,6 +4,7 @@ import os.path as osp
import unittest
import numpy as np
import pytest
from mmengine.structures import BaseDataElement
from mmseg.datasets.transforms import PackSegInputs
@ -46,8 +47,11 @@ class TestPackSegInputs(unittest.TestCase):
self.assertEqual(results['data_samples'].ori_shape,
results['data_samples'].gt_sem_seg.shape)
results = copy.deepcopy(self.results)
# test dataset shape is not 2D
results['gt_seg_map'] = np.random.rand(3, 300, 400)
results = transform(results)
msg = 'the segmentation map is 2D'
with pytest.warns(UserWarning, match=msg):
results = transform(results)
self.assertEqual(results['data_samples'].ori_shape,
results['data_samples'].gt_sem_seg.shape)