# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import unittest

import numpy as np
from mmengine.data import BaseDataElement

from mmseg.data import SegDataSample
from mmseg.datasets.transforms import PackSegInputs


class TestPackSegInputs(unittest.TestCase):
    """Unit tests for the ``PackSegInputs`` data transform."""

    def setUp(self):
        """Build the ``results`` dict and ``meta_keys`` used by every test.

        ``self.results`` holds a random 300x400 ``img`` and ``gt_seg_map``
        together with metadata entries (paths, shapes, flip info, ...).
        ``self.meta_keys`` is the tuple of metadata keys handed to
        ``PackSegInputs`` in each test.

        TestCase calls functions in this order: setUp() -> testMethod() ->
        tearDown() -> cleanUp()
        """
        data_prefix = osp.join(osp.dirname(__file__), '../../data')
        # This test never loads the file itself; the pixels below are random.
        img_path = osp.join(data_prefix, 'color.jpg')
        # Fixed seed so the random image and label map are reproducible.
        rng = np.random.RandomState(0)
        self.results = {
            'img_path': img_path,
            'ori_shape': (300, 400),
            'pad_shape': (600, 800),
            'img_shape': (600, 800),
            'scale_factor': 2.0,
            'flip': False,
            'flip_direction': 'horizontal',
            'img_norm_cfg': None,
            'img': rng.rand(300, 400),
            'gt_seg_map': rng.rand(300, 400),
        }
        self.meta_keys = ('img_path', 'ori_shape', 'img_shape', 'pad_shape',
                          'scale_factor', 'flip', 'flip_direction')

    def test_transform(self):
        """Packing yields a ``SegDataSample`` with a well-formed label map.

        The packed sample must expose ``gt_sem_seg`` as a
        ``BaseDataElement`` whose shape equals the sample's ``ori_shape``.
        """
        transform = PackSegInputs(meta_keys=self.meta_keys)
        # Deep-copy so the shared fixture is not mutated by the transform.
        results = transform(copy.deepcopy(self.results))
        self.assertIn('data_sample', results)
        self.assertIsInstance(results['data_sample'], SegDataSample)
        self.assertIsInstance(results['data_sample'].gt_sem_seg,
                              BaseDataElement)
        self.assertEqual(results['data_sample'].ori_shape,
                         results['data_sample'].gt_sem_seg.shape)

    def test_repr(self):
        """``repr`` of the transform reports the configured ``meta_keys``."""
        transform = PackSegInputs(meta_keys=self.meta_keys)
        self.assertEqual(
            repr(transform), f'PackSegInputs(meta_keys={self.meta_keys})')