# mmsegmentation/tests/test_datasets/test_formatting.py

# 53 lines
# 1.9 KiB
# Python
# Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import unittest
import numpy as np
from mmengine.data import BaseDataElement
# 2022-07-15 23:47:29 +08:00
from mmseg.data import SegDataSample
from mmseg.datasets.transforms import PackSegInputs
class TestPackSegInputs(unittest.TestCase):
    """Unit tests for the ``PackSegInputs`` transform."""

    def setUp(self):
        """Build the fixture shared by every test method.

        ``unittest`` calls this before each test method. It stores a
        ``results`` dict (image path, shape metadata, flip settings and two
        random 300x400 arrays) on ``self.results`` and the tuple of metadata
        keys handed to ``PackSegInputs`` on ``self.meta_keys``.
        """
        data_root = osp.join(osp.dirname(__file__), '../../data')
        generator = np.random.RandomState(0)
        # Both arrays come from the same seeded generator, so the draw order
        # (image first, segmentation map second) determines their values.
        image = generator.rand(300, 400)
        seg_map = generator.rand(300, 400)
        self.results = dict(
            # The path is only recorded here; this method does not open it.
            img_path=osp.join(data_root, 'color.jpg'),
            ori_shape=(300, 400),
            pad_shape=(600, 800),
            img_shape=(600, 800),
            scale_factor=2.0,
            flip=False,
            flip_direction='horizontal',
            img_norm_cfg=None,
            img=image,
            gt_seg_map=seg_map,
        )
        self.meta_keys = (
            'img_path',
            'ori_shape',
            'img_shape',
            'pad_shape',
            'scale_factor',
            'flip',
            'flip_direction',
        )

    def test_transform(self):
        """Check the type and shape of the packed ``data_sample``.

        The transform is applied to a deep copy of the fixture so that
        ``self.results`` itself is never handed to the transform.
        """
        packer = PackSegInputs(meta_keys=self.meta_keys)
        packed = packer(copy.deepcopy(self.results))
        self.assertIn('data_sample', packed)
        sample = packed['data_sample']
        self.assertIsInstance(sample, SegDataSample)
        self.assertIsInstance(sample.gt_sem_seg, BaseDataElement)
        # The packed ground truth must report the same shape as the
        # ``ori_shape`` metadata carried by the sample.
        self.assertEqual(sample.ori_shape, sample.gt_sem_seg.shape)

    def test_repr(self):
        """Check that ``repr`` of the transform lists its ``meta_keys``."""
        packer = PackSegInputs(meta_keys=self.meta_keys)
        actual = repr(packer)
        expected = f'PackSegInputs(meta_keys={self.meta_keys})'
        self.assertEqual(actual, expected)