# 50 lines
# 1.7 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import mmcls.datasets # noqa: F401,F403
from mmcls.registry import TRANSFORMS
class TestResizeEdge(TestCase):
    """Unit tests for the ``ResizeEdge`` transform built via ``TRANSFORMS``."""

    @staticmethod
    def _make_results(height, width):
        """Return a fresh results dict holding a random uint8 3-channel image.

        A new dict is created for every case so that each assertion exercises
        a real resize instead of operating on the previous case's output
        (which could turn a case into a no-op).
        """
        img = np.random.randint(0, 256, (height, width, 3), np.uint8)
        return dict(img=img)

    def _check_shape(self, cfg, input_hw, expected_shape):
        """Build ``cfg``, apply it to a fresh ``input_hw`` image and assert
        the resulting ``results['img'].shape`` equals ``expected_shape``."""
        transform = TRANSFORMS.build(cfg)
        results = transform(self._make_results(*input_hw))
        self.assertTupleEqual(results['img'].shape, expected_shape)

    def test_transform(self):
        # Each case scales the selected edge to 224 and the expected shapes
        # keep the input aspect ratio. Landscape input is (h=128, w=256);
        # portrait input is (h=256, w=128).
        landscape = (128, 256)
        portrait = (256, 128)

        # test resize short edge by default.
        cfg = dict(type='ResizeEdge', scale=224)
        self._check_shape(cfg, landscape, (224, 448, 3))

        # test resize long edge.
        cfg = dict(type='ResizeEdge', scale=224, edge='long')
        self._check_shape(cfg, landscape, (112, 224, 3))

        # test resize width. The portrait case makes the width the short
        # edge, so 'width' cannot silently behave like 'long' and pass.
        cfg = dict(type='ResizeEdge', scale=224, edge='width')
        self._check_shape(cfg, landscape, (112, 224, 3))
        self._check_shape(cfg, portrait, (448, 224, 3))

        # test resize height. The portrait case makes the height the long
        # edge, so 'height' cannot silently behave like 'short' and pass.
        cfg = dict(type='ResizeEdge', scale=224, edge='height')
        self._check_shape(cfg, landscape, (224, 448, 3))
        self._check_shape(cfg, portrait, (224, 112, 3))

        # test invalid edge
        cfg = dict(type='ResizeEdge', scale=224, edge='hi')
        with self.assertRaisesRegex(AssertionError, 'Invalid edge "hi"'):
            TRANSFORMS.build(cfg)

    def test_repr(self):
        """The repr lists scale, edge, backend and interpolation."""
        cfg = dict(type='ResizeEdge', scale=224, edge='height')
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), 'ResizeEdge(scale=224, edge=height, backend=cv2, '
            'interpolation=bilinear)')