mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Thanks for your contribution; we appreciate it a lot. The following instructions will help make your pull request healthier and get feedback more easily. If you do not understand some items, don't worry: just make the pull request and seek help from the maintainers. ## Motivation fix #2593 ## Modification 1. Only when the gt seg map is 2D, extend its shape to a 3D PixelData. 2. If the seg map is not 2D, raise a warning for users. --------- Co-authored-by: xiexinch <xiexinch@outlook.com>
58 lines
2.1 KiB
Python
58 lines
2.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
import os.path as osp
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from mmengine.structures import BaseDataElement
|
|
|
|
from mmseg.datasets.transforms import PackSegInputs
|
|
from mmseg.structures import SegDataSample
|
|
|
|
|
|
class TestPackSegInputs(unittest.TestCase):
    """Unit tests for the ``PackSegInputs`` transform."""

    def setUp(self):
        """Build a deterministic ``results`` dict shared by every test.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> cleanUp()
        """
        data_prefix = osp.join(osp.dirname(__file__), '../../data')
        img_path = osp.join(data_prefix, 'color.jpg')
        # Seeded so every run packs identical arrays.
        rng = np.random.RandomState(0)
        self.results = {
            'img_path': img_path,
            'ori_shape': (300, 400),
            'pad_shape': (600, 800),
            'img_shape': (600, 800),
            'scale_factor': 2.0,
            'flip': False,
            'flip_direction': 'horizontal',
            'img_norm_cfg': None,
            'img': rng.rand(300, 400),
            'gt_seg_map': rng.rand(300, 400),
        }
        # Metainfo keys the transform is asked to copy onto the data sample.
        self.meta_keys = ('img_path', 'ori_shape', 'img_shape', 'pad_shape',
                          'scale_factor', 'flip', 'flip_direction')

    def test_transform(self):
        """Pack both a 2D and a non-2D ground-truth map into ``gt_sem_seg``.

        A 2D ``(H, W)`` map is expected to gain a leading axis, giving
        ``(1, H, W)``. A non-2D map is expected to be packed unchanged and
        to trigger a warning.
        """
        transform = PackSegInputs(meta_keys=self.meta_keys)

        # Case 1: 2D (300, 400) ground-truth map.
        results = transform(copy.deepcopy(self.results))
        self.assertIn('data_samples', results)
        self.assertIsInstance(results['data_samples'], SegDataSample)
        self.assertIsInstance(results['data_samples'].gt_sem_seg,
                              BaseDataElement)
        self.assertEqual(results['data_samples'].ori_shape,
                         results['data_samples'].gt_sem_seg.shape)
        # The spatial-shape check above passes for either branch, so pin
        # the packed tensor shape to confirm the 2D map was expanded.
        self.assertEqual(
            tuple(results['data_samples'].gt_sem_seg.data.shape),
            (1, 300, 400))

        # Case 2: 3D (3, 300, 400) map. Use a seeded generator (not the
        # global np.random state) so the test stays deterministic.
        results = copy.deepcopy(self.results)
        results['gt_seg_map'] = np.random.RandomState(1).rand(3, 300, 400)
        with self.assertWarns(UserWarning):
            results = transform(results)
        self.assertEqual(results['data_samples'].ori_shape,
                         results['data_samples'].gt_sem_seg.shape)
        # A non-2D map must not get an extra leading axis.
        self.assertEqual(
            tuple(results['data_samples'].gt_sem_seg.data.shape),
            (3, 300, 400))

    def test_repr(self):
        """``repr`` reports the configured ``meta_keys``."""
        transform = PackSegInputs(meta_keys=self.meta_keys)
        self.assertEqual(
            repr(transform), f'PackSegInputs(meta_keys={self.meta_keys})')
|