mirror of https://github.com/open-mmlab/mmcv.git (synced 2025-06-03 21:54:52 +08:00)
commit ec43b671ab (parent 1290bdd181)
@@ -6,16 +6,28 @@ from .registry import ACTIVATION_LAYERS
 @ACTIVATION_LAYERS.register_module()
 class HSigmoid(nn.Module):
     """Hard Sigmoid Module. Apply the hard sigmoid function:
-    Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
+    Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
+    Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
+
+    Args:
+        bias (float): Bias of the input feature map. Default: 1.0.
+        divisor (float): Divisor of the input feature map. Default: 2.0.
+        min_value (float): Lower bound value. Default: 0.0.
+        max_value (float): Upper bound value. Default: 1.0.
 
     Returns:
         Tensor: The output tensor.
     """
 
-    def __init__(self):
+    def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
         super(HSigmoid, self).__init__()
+        self.bias = bias
+        self.divisor = divisor
+        assert self.divisor != 0
+        self.min_value = min_value
+        self.max_value = max_value
 
     def forward(self, x):
-        x = (x + 1) / 2
+        x = (x + self.bias) / self.divisor
 
-        return x.clamp_(0, 1)
+        return x.clamp_(self.min_value, self.max_value)
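A minimal usage sketch (not part of this commit), assuming only the public import path that the new test uses (from mmcv.cnn.bricks import HSigmoid). The generalized module computes clamp((x + bias) / divisor, min_value, max_value); the defaults reproduce the old hard-coded min(max((x + 1) / 2, 0), 1).

import torch
from mmcv.cnn.bricks import HSigmoid

x = torch.randn(2, 4)

# bias=3, divisor=6 gives the MobileNetV3-style hard sigmoid
act = HSigmoid(bias=3.0, divisor=6.0)
out = act(x)

# reference computation with plain tensor ops; forward() does not
# modify x in place, so x can be reused here
ref = ((x + 3.0) / 6.0).clamp(0.0, 1.0)
assert torch.equal(out, ref)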
@@ -1,9 +1,15 @@
+import pytest
 import torch
 
 from mmcv.cnn.bricks import HSigmoid
 
 
 def test_hsigmoid():
+    # test assertion divisor can not be zero
+    with pytest.raises(AssertionError):
+        HSigmoid(divisor=0)
+
+    # test with default parameters
     act = HSigmoid()
     input_shape = torch.Size([1, 3, 64, 64])
     input = torch.randn(input_shape)
@@ -15,3 +21,16 @@ def test_hsigmoid():
     assert output.shape == expected_output.shape
     # test output value
     assert torch.equal(output, expected_output)
+
+    # test with designated parameters
+    act = HSigmoid(3, 6, 0, 1)
+    input_shape = torch.Size([1, 3, 64, 64])
+    input = torch.randn(input_shape)
+    output = act(input)
+    expected_output = torch.min(
+        torch.max((input + 3) / 6, torch.zeros(input_shape)),
+        torch.ones(input_shape))
+    # test output shape
+    assert output.shape == expected_output.shape
+    # test output value
+    assert torch.equal(output, expected_output)
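The new test builds expected_output with torch.min/torch.max rather than clamp; a quick standalone check (hypothetical, not part of the commit) that the two formulations agree elementwise:

import torch

x = torch.randn(1, 3, 64, 64)
# elementwise max with zeros then min with ones ...
via_min_max = torch.min(torch.max((x + 3) / 6, torch.zeros_like(x)),
                        torch.ones_like(x))
# ... matches a plain clamp to [0, 1]
via_clamp = ((x + 3) / 6).clamp(0, 1)
assert torch.equal(via_min_max, via_clamp)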