mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
parent
dd47cef801
commit
3cc7ae2167
@ -74,12 +74,12 @@ class PackSegInputs(BaseTransform):
|
||||
|
||||
data_sample = SegDataSample()
|
||||
if 'gt_seg_map' in results:
|
||||
if results['gt_seg_map'].shape == 2:
|
||||
if len(results['gt_seg_map'].shape) == 2:
|
||||
data = to_tensor(results['gt_seg_map'][None,
|
||||
...].astype(np.int64))
|
||||
else:
|
||||
warnings.warn('Please pay attention your ground truth '
|
||||
'segmentation map, usually the segentation '
|
||||
'segmentation map, usually the segmentation '
|
||||
'map is 2D, but got '
|
||||
f'{results["gt_seg_map"].shape}')
|
||||
data = to_tensor(results['gt_seg_map'].astype(np.int64))
|
||||
|
@ -4,6 +4,7 @@ import os.path as osp
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mmengine.structures import BaseDataElement
|
||||
|
||||
from mmseg.datasets.transforms import PackSegInputs
|
||||
@ -46,8 +47,11 @@ class TestPackSegInputs(unittest.TestCase):
|
||||
self.assertEqual(results['data_samples'].ori_shape,
|
||||
results['data_samples'].gt_sem_seg.shape)
|
||||
results = copy.deepcopy(self.results)
|
||||
# test dataset shape is not 2D
|
||||
results['gt_seg_map'] = np.random.rand(3, 300, 400)
|
||||
results = transform(results)
|
||||
msg = 'the segmentation map is 2D'
|
||||
with pytest.warns(UserWarning, match=msg):
|
||||
results = transform(results)
|
||||
self.assertEqual(results['data_samples'].ori_shape,
|
||||
results['data_samples'].gt_sem_seg.shape)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user