mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
50 lines
1.7 KiB
Python
50 lines
1.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
|
|
import numpy as np
|
|
|
|
import mmcls.datasets # noqa: F401,F403
|
|
from mmcls.registry import TRANSFORMS
|
|
|
|
|
|
class TestResizeEdge(TestCase):
    """Unit tests for the ``ResizeEdge`` transform built via ``TRANSFORMS``."""

    def test_transform(self):
        """Check output shapes for every ``edge`` mode and invalid-edge rejection.

        Every case starts from the same fresh 128x256 image, so the cases are
        independent of one another. Chaining one case's output into the next
        would make the ``'width'`` case a no-op, because the width would
        already be 224 before that case ran.
        """
        img = np.random.randint(0, 256, (128, 256, 3), np.uint8)

        # (edge, expected output shape); ``None`` means "use the default edge".
        cases = [
            (None, (224, 448, 3)),      # default resizes the short edge: h 128 -> 224
            ('long', (112, 224, 3)),    # long edge is the width: w 256 -> 224
            ('width', (112, 224, 3)),   # w 256 -> 224, height scales with it
            ('height', (224, 448, 3)),  # h 128 -> 224, width scales with it
        ]
        for edge, expected_shape in cases:
            with self.subTest(edge=edge):
                cfg = dict(type='ResizeEdge', scale=224)
                if edge is not None:
                    cfg['edge'] = edge
                transform = TRANSFORMS.build(cfg)
                # Copy so that an in-place transform cannot leak into later cases.
                results = transform(dict(img=img.copy()))
                self.assertTupleEqual(results['img'].shape, expected_shape)

        # test invalid edge: only the build call belongs inside the assertion.
        cfg = dict(type='ResizeEdge', scale=224, edge='hi')
        with self.assertRaisesRegex(AssertionError, 'Invalid edge "hi"'):
            TRANSFORMS.build(cfg)

    def test_repr(self):
        """Check the string representation of a configured ``ResizeEdge``."""
        cfg = dict(type='ResizeEdge', scale=224, edge='height')
        transform = TRANSFORMS.build(cfg)
        self.assertEqual(
            repr(transform), 'ResizeEdge(scale=224, edge=height, backend=cv2, '
            'interpolation=bilinear)')
|