mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Align HSigmoid with pytorch official implementation(#1622)
* [Fix] align hsigmoid with pytorch official * [Fix] add warnings for Hsigmoid * [Fix] fix format * [Fix] add unittest * [Fix] fix docstringpull/1445/merge
parent
b6e1ab7e83
commit
b8d78336a7
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .registry import ACTIVATION_LAYERS
|
||||
|
@ -8,11 +10,15 @@ from .registry import ACTIVATION_LAYERS
|
|||
class HSigmoid(nn.Module):
|
||||
"""Hard Sigmoid Module. Apply the hard sigmoid function:
|
||||
Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
|
||||
Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
|
||||
Default: Hsigmoid(x) = min(max((x + 3) / 6, 0), 1)
|
||||
|
||||
Note:
|
||||
In MMCV v1.4.4, we modified the default value of args to align with
|
||||
PyTorch official.
|
||||
|
||||
Args:
|
||||
bias (float): Bias of the input feature map. Default: 1.0.
|
||||
divisor (float): Divisor of the input feature map. Default: 2.0.
|
||||
bias (float): Bias of the input feature map. Default: 3.0.
|
||||
divisor (float): Divisor of the input feature map. Default: 6.0.
|
||||
min_value (float): Lower bound value. Default: 0.0.
|
||||
max_value (float): Upper bound value. Default: 1.0.
|
||||
|
||||
|
@ -20,8 +26,14 @@ class HSigmoid(nn.Module):
|
|||
Tensor: The output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
|
||||
def __init__(self, bias=3.0, divisor=6.0, min_value=0.0, max_value=1.0):
|
||||
super(HSigmoid, self).__init__()
|
||||
warnings.warn(
|
||||
'In MMCV v1.4.4, we modified the default value of args to align '
|
||||
'with PyTorch official. Previous Implementation: '
|
||||
'Hsigmoid(x) = min(max((x + 1) / 2, 0), 1). '
|
||||
'Current Implementation: '
|
||||
'Hsigmoid(x) = min(max((x + 3) / 6, 0), 1).')
|
||||
self.bias = bias
|
||||
self.divisor = divisor
|
||||
assert self.divisor != 0
|
||||
|
|
|
@ -6,4 +6,4 @@ onnxruntime>=1.8.0
|
|||
pytest
|
||||
PyTurboJPEG
|
||||
scipy
|
||||
tiffile
|
||||
tifffile
|
||||
|
|
|
@ -15,7 +15,7 @@ def test_hsigmoid():
|
|||
input = torch.randn(input_shape)
|
||||
output = act(input)
|
||||
expected_output = torch.min(
|
||||
torch.max((input + 1) / 2, torch.zeros(input_shape)),
|
||||
torch.max((input + 3) / 6, torch.zeros(input_shape)),
|
||||
torch.ones(input_shape))
|
||||
# test output shape
|
||||
assert output.shape == expected_output.shape
|
||||
|
@ -23,12 +23,12 @@ def test_hsigmoid():
|
|||
assert torch.equal(output, expected_output)
|
||||
|
||||
# test with designated parameters
|
||||
act = HSigmoid(3, 6, 0, 1)
|
||||
act = HSigmoid(1, 2, 0, 1)
|
||||
input_shape = torch.Size([1, 3, 64, 64])
|
||||
input = torch.randn(input_shape)
|
||||
output = act(input)
|
||||
expected_output = torch.min(
|
||||
torch.max((input + 3) / 6, torch.zeros(input_shape)),
|
||||
torch.max((input + 1) / 2, torch.zeros(input_shape)),
|
||||
torch.ones(input_shape))
|
||||
# test output shape
|
||||
assert output.shape == expected_output.shape
|
||||
|
|
Loading…
Reference in New Issue