mirror of https://github.com/open-mmlab/mmocr.git
56 lines
2.1 KiB
Python
56 lines
2.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
import unittest
|
|
|
|
import numpy as np
|
|
|
|
from mmocr.datasets.pipelines import PyramidRescale
|
|
|
|
|
|
class TestPyramidRescale(unittest.TestCase):
    """Unit tests for the ``PyramidRescale`` pipeline transform."""

    def setUp(self):
        # A random 128x100 three-channel image. Each test deep-copies this
        # dict before calling the transform, so the shared fixture is never
        # modified in place.
        self.data_info = dict(img=np.random.random((128, 100, 3)))

    def test_init(self):
        """The constructor stores valid arguments and rejects invalid ones."""
        # factor is int
        transform = PyramidRescale(factor=4, randomize_factor=False)
        self.assertEqual(transform.factor, 4)

        # factor is float
        with self.assertRaisesRegex(TypeError,
                                    '`factor` should be an integer'):
            PyramidRescale(factor=4.0)

        # base_shape is neither a list nor a tuple
        with self.assertRaisesRegex(TypeError,
                                    '`base_shape` should be a list or tuple'):
            PyramidRescale(base_shape=128)
        # base_shape has the wrong number of elements
        with self.assertRaisesRegex(
                ValueError, '`base_shape` should contain two integers'):
            PyramidRescale(base_shape=(128, ))
        # base_shape contains non-integer elements
        with self.assertRaisesRegex(
                ValueError, '`base_shape` should contain two integers'):
            PyramidRescale(base_shape=(128.0, 2.0))

        # invalid randomize_factor
        with self.assertRaisesRegex(TypeError,
                                    '`randomize_factor` should be a bool'):
            PyramidRescale(randomize_factor=None)

    def test_transform(self):
        """The transform keeps the image size; factor 0 is a no-op."""
        # The default settings presumably randomize the factor (the other
        # tests pass randomize_factor=False to pin it), so the rescale path
        # might not be exercised. A fixed non-zero factor is checked as well.
        for transform in (PyramidRescale(),
                          PyramidRescale(factor=2, randomize_factor=False)):
            results = transform(copy.deepcopy(self.data_info))
            self.assertEqual(results['img'].shape, (128, 100, 3))

        # factor = 0 must leave the image pixel-for-pixel unchanged
        transform = PyramidRescale(factor=0, randomize_factor=False)
        results = transform(copy.deepcopy(self.data_info))
        np.testing.assert_array_equal(results['img'], self.data_info['img'])

    def test_repr(self):
        """``repr`` reports the configured parameters."""
        transform = PyramidRescale(
            factor=4, base_shape=(128, 512), randomize_factor=False)
        self.assertEqual(
            repr(transform),
            ('PyramidRescale(factor = 4, randomize_factor = False, '
             'base_w = 128, base_h = 512)'))
|