# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import unittest

import numpy as np
import pytest
from mmengine.structures import BaseDataElement

from mmseg.datasets.transforms import PackSegInputs
from mmseg.structures import SegDataSample


class TestPackSegInputs(unittest.TestCase):
    """Unit tests for the ``PackSegInputs`` data transform."""

    def setUp(self):
        """Build the input ``results`` dict and the ``meta_keys`` tuple shared
        by every test method.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> cleanUp()
        """
        data_prefix = osp.join(osp.dirname(__file__), '../../data')
        img_path = osp.join(data_prefix, 'color.jpg')
        # Seeded generator so the random image / map are reproducible.
        rng = np.random.RandomState(0)
        self.results = {
            'img_path': img_path,
            'ori_shape': (300, 400),
            'pad_shape': (600, 800),
            'img_shape': (600, 800),
            'scale_factor': 2.0,
            'flip': False,
            'flip_direction': 'horizontal',
            'img_norm_cfg': None,
            'img': rng.rand(300, 400),
            'gt_seg_map': rng.rand(300, 400),
        }
        self.meta_keys = ('img_path', 'ori_shape', 'img_shape', 'pad_shape',
                          'scale_factor', 'flip', 'flip_direction')

    def test_transform(self):
        """Check packing of a 2-D map and the warning path for a 3-D map."""
        transform = PackSegInputs(meta_keys=self.meta_keys)
        results = transform(copy.deepcopy(self.results))
        self.assertIn('data_samples', results)
        self.assertIsInstance(results['data_samples'], SegDataSample)
        self.assertIsInstance(results['data_samples'].gt_sem_seg,
                              BaseDataElement)
        self.assertEqual(results['data_samples'].ori_shape,
                         results['data_samples'].gt_sem_seg.shape)

        results = copy.deepcopy(self.results)
        # A segmentation map that is not 2-D (here 3 x H x W) should trigger
        # a UserWarning but still be packed. Use a seeded generator so the
        # test stays deterministic.
        rng = np.random.RandomState(0)
        results['gt_seg_map'] = rng.rand(3, 300, 400)
        msg = 'the segmentation map is 2D'
        with pytest.warns(UserWarning, match=msg):
            results = transform(results)
        self.assertEqual(results['data_samples'].ori_shape,
                         results['data_samples'].gt_sem_seg.shape)

    def test_repr(self):
        """The repr should list the configured ``meta_keys``."""
        transform = PackSegInputs(meta_keys=self.meta_keys)
        self.assertEqual(
            repr(transform), f'PackSegInputs(meta_keys={self.meta_keys})')