mirror of https://github.com/open-mmlab/mmcv.git
37 lines
1.1 KiB
Python
37 lines
1.1 KiB
Python
import pytest
|
|
import torch
|
|
|
|
from mmcv.cnn.bricks import HSigmoid
|
|
|
|
|
|
def test_hsigmoid():
    """Exercise HSigmoid construction checks and its forward output.

    Each forward result is compared, for both shape and exact values,
    against a reference computed here as
    ``min(max((x + bias) / divisor, 0), 1)``.
    """
    # test assertion divisor can not be zero
    with pytest.raises(AssertionError):
        HSigmoid(divisor=0)

    shape = torch.Size([1, 3, 64, 64])

    def reference(x, bias, divisor):
        # Shift and scale, then bound the result between 0 and 1.
        lower = torch.zeros(shape)
        upper = torch.ones(shape)
        return torch.min(torch.max((x + bias) / divisor, lower), upper)

    # Each case: (positional constructor args, bias and divisor used by the
    # reference). The first case tests the default parameters, the second
    # tests designated parameters.
    cases = (
        ((), 1, 2),
        ((3, 6, 0, 1), 3, 6),
    )
    for ctor_args, bias, divisor in cases:
        layer = HSigmoid(*ctor_args)
        x = torch.randn(shape)
        result = layer(x)
        expected = reference(x, bias, divisor)
        # test output shape
        assert result.shape == expected.shape
        # test output value
        assert torch.equal(result, expected)
|