mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
parent
dd47cef801
commit
3cc7ae2167
@ -74,12 +74,12 @@ class PackSegInputs(BaseTransform):
|
|||||||
|
|
||||||
data_sample = SegDataSample()
|
data_sample = SegDataSample()
|
||||||
if 'gt_seg_map' in results:
|
if 'gt_seg_map' in results:
|
||||||
if results['gt_seg_map'].shape == 2:
|
if len(results['gt_seg_map'].shape) == 2:
|
||||||
data = to_tensor(results['gt_seg_map'][None,
|
data = to_tensor(results['gt_seg_map'][None,
|
||||||
...].astype(np.int64))
|
...].astype(np.int64))
|
||||||
else:
|
else:
|
||||||
warnings.warn('Please pay attention your ground truth '
|
warnings.warn('Please pay attention your ground truth '
|
||||||
'segmentation map, usually the segentation '
|
'segmentation map, usually the segmentation '
|
||||||
'map is 2D, but got '
|
'map is 2D, but got '
|
||||||
f'{results["gt_seg_map"].shape}')
|
f'{results["gt_seg_map"].shape}')
|
||||||
data = to_tensor(results['gt_seg_map'].astype(np.int64))
|
data = to_tensor(results['gt_seg_map'].astype(np.int64))
|
||||||
|
@ -4,6 +4,7 @@ import os.path as osp
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
from mmengine.structures import BaseDataElement
|
from mmengine.structures import BaseDataElement
|
||||||
|
|
||||||
from mmseg.datasets.transforms import PackSegInputs
|
from mmseg.datasets.transforms import PackSegInputs
|
||||||
@ -46,7 +47,10 @@ class TestPackSegInputs(unittest.TestCase):
|
|||||||
self.assertEqual(results['data_samples'].ori_shape,
|
self.assertEqual(results['data_samples'].ori_shape,
|
||||||
results['data_samples'].gt_sem_seg.shape)
|
results['data_samples'].gt_sem_seg.shape)
|
||||||
results = copy.deepcopy(self.results)
|
results = copy.deepcopy(self.results)
|
||||||
|
# test dataset shape is not 2D
|
||||||
results['gt_seg_map'] = np.random.rand(3, 300, 400)
|
results['gt_seg_map'] = np.random.rand(3, 300, 400)
|
||||||
|
msg = 'the segmentation map is 2D'
|
||||||
|
with pytest.warns(UserWarning, match=msg):
|
||||||
results = transform(results)
|
results = transform(results)
|
||||||
self.assertEqual(results['data_samples'].ori_shape,
|
self.assertEqual(results['data_samples'].ori_shape,
|
||||||
results['data_samples'].gt_sem_seg.shape)
|
results['data_samples'].gt_sem_seg.shape)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user