mirror of https://github.com/open-mmlab/mmocr.git
67 lines
2.2 KiB
Python
67 lines
2.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
import unittest.mock as mock
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from mmocr.datasets.pipelines import (OneOfWrapper, RandomWrapper,
|
|
TorchVisionWrapper)
|
|
from mmocr.datasets.pipelines.transforms import ColorJitter
|
|
|
|
|
|
def test_torchvision_wrapper():
|
|
x = {'img': np.ones((128, 100, 3), dtype=np.uint8)}
|
|
# object not found error
|
|
with pytest.raises(Exception):
|
|
TorchVisionWrapper(op='NonExist')
|
|
with pytest.raises(TypeError):
|
|
TorchVisionWrapper()
|
|
f = TorchVisionWrapper('Grayscale')
|
|
with pytest.raises(AssertionError):
|
|
f({})
|
|
results = f(x)
|
|
assert results['img'].shape == (128, 100)
|
|
assert results['img_shape'] == (128, 100)
|
|
|
|
|
|
@mock.patch('random.choice')
|
|
def test_oneof(rand_choice):
|
|
color_jitter = dict(type='TorchVisionWrapper', op='ColorJitter')
|
|
gray_scale = dict(type='TorchVisionWrapper', op='Grayscale')
|
|
x = {'img': np.random.randint(0, 256, size=(128, 100, 3), dtype=np.uint8)}
|
|
f = OneOfWrapper([color_jitter, gray_scale])
|
|
# Use color_jitter at the first call
|
|
rand_choice.side_effect = lambda x: x[0]
|
|
results = f(x)
|
|
assert results['img'].shape == (128, 100, 3)
|
|
# Use gray_scale at the second call
|
|
rand_choice.side_effect = lambda x: x[1]
|
|
results = f(x)
|
|
assert results['img'].shape == (128, 100)
|
|
|
|
# Passing object
|
|
f = OneOfWrapper([ColorJitter(), gray_scale])
|
|
# Use color_jitter at the first call
|
|
results = f(x)
|
|
assert results['img'].shape == (128, 100)
|
|
|
|
# Test invalid inputs
|
|
with pytest.raises(AssertionError):
|
|
f = OneOfWrapper(None)
|
|
with pytest.raises(AssertionError):
|
|
f = OneOfWrapper([])
|
|
with pytest.raises(AssertionError):
|
|
f = OneOfWrapper({})
|
|
|
|
|
|
@mock.patch('numpy.random.uniform')
|
|
def test_runwithprob(np_random_uniform):
|
|
np_random_uniform.side_effect = [0.1, 0.9]
|
|
f = RandomWrapper([dict(type='TorchVisionWrapper', op='Grayscale')], 0.5)
|
|
img = np.random.randint(0, 256, size=(128, 100, 3), dtype=np.uint8)
|
|
results = f({'img': copy.deepcopy(img)})
|
|
assert results['img'].shape == (128, 100)
|
|
results = f({'img': copy.deepcopy(img)})
|
|
assert results['img'].shape == (128, 100, 3)
|