[Enhance] Enable to toggle whether Gem Pooling is trainable or not. (#1246)
* Enable to toggle whether Gem Pooling is trainable or not. * Add test case whether Gem Pooling is trainable or not. * Enable to toggle whether Gem Pooling is trainable or not by requires_grad --------- Co-authored-by: Yusuke Fujimoto <yusuke.fujimoto@rist.co.jp>pull/1386/head
parent
a3f2effb17
commit
4ce7be17c9
|
@ -22,20 +22,20 @@ class GeneralizedMeanPooling(nn.Module):
|
|||
has a batch dimension of size 1, which can lead to unexpected errors.
|
||||
|
||||
Args:
|
||||
p (float): Parameter value.
|
||||
Default: 3.
|
||||
eps (float): epsilon.
|
||||
Default: 1e-6
|
||||
clamp (bool): Use clamp before pooling.
|
||||
Default: True
|
||||
p (float): Parameter value. Defaults to 3.
|
||||
eps (float): epsilon. Defaults to 1e-6.
|
||||
clamp (bool): Use clamp before pooling. Defaults to True
|
||||
p_trainable (bool): Toggle whether Parameter p is trainable or not.
|
||||
Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, p=3., eps=1e-6, clamp=True):
|
||||
def __init__(self, p=3., eps=1e-6, clamp=True, p_trainable=True):
|
||||
assert p >= 1, "'p' must be a value greater than 1"
|
||||
super(GeneralizedMeanPooling, self).__init__()
|
||||
self.p = Parameter(torch.ones(1) * p)
|
||||
self.p = Parameter(torch.ones(1) * p, requires_grad=p_trainable)
|
||||
self.eps = eps
|
||||
self.clamp = clamp
|
||||
self.p_trainable = p_trainable
|
||||
|
||||
def forward(self, inputs):
|
||||
if isinstance(inputs, tuple):
|
||||
|
|
|
@ -44,6 +44,10 @@ def test_gem_neck():
|
|||
|
||||
# test gem_neck
|
||||
neck = GeneralizedMeanPooling()
|
||||
|
||||
# default p is trainable
|
||||
assert neck.p.requires_grad
|
||||
|
||||
# batch_size, num_features, feature_size(2)
|
||||
fake_input = torch.rand(1, 16, 24, 24)
|
||||
|
||||
|
@ -61,6 +65,19 @@ def test_gem_neck():
|
|||
assert output[0].shape == (1, 8)
|
||||
assert output[1].shape == (1, 16)
|
||||
|
||||
# test gem_neck with p_trainable=False
|
||||
neck = GeneralizedMeanPooling(p_trainable=False)
|
||||
|
||||
# p is not trainable
|
||||
assert not neck.p.requires_grad
|
||||
|
||||
# batch_size, num_features, feature_size(2)
|
||||
fake_input = torch.rand(1, 16, 24, 24)
|
||||
|
||||
output = neck(fake_input)
|
||||
# batch_size, num_features
|
||||
assert output.shape == (1, 16)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# p must be a value greater then 1
|
||||
GeneralizedMeanPooling(p=0.5)
|
||||
|
|
Loading…
Reference in New Issue