[Enhance] Enable to toggle whether Gem Pooling is trainable or not. (#1246)
* Enable to toggle whether Gem Pooling is trainable or not. * Add test case whether Gem Pooling is trainable or not. * Enable to toggle whether Gem Pooling is trainable or not by requires_grad --------- Co-authored-by: Yusuke Fujimoto <yusuke.fujimoto@rist.co.jp>pull/1386/head
parent
a3f2effb17
commit
4ce7be17c9
|
@ -22,20 +22,20 @@ class GeneralizedMeanPooling(nn.Module):
|
||||||
has a batch dimension of size 1, which can lead to unexpected errors.
|
has a batch dimension of size 1, which can lead to unexpected errors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
p (float): Parameter value.
|
p (float): Parameter value. Defaults to 3.
|
||||||
Default: 3.
|
eps (float): epsilon. Defaults to 1e-6.
|
||||||
eps (float): epsilon.
|
clamp (bool): Use clamp before pooling. Defaults to True
|
||||||
Default: 1e-6
|
p_trainable (bool): Toggle whether Parameter p is trainable or not.
|
||||||
clamp (bool): Use clamp before pooling.
|
Defaults to True.
|
||||||
Default: True
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, p=3., eps=1e-6, clamp=True):
|
def __init__(self, p=3., eps=1e-6, clamp=True, p_trainable=True):
|
||||||
assert p >= 1, "'p' must be a value greater than 1"
|
assert p >= 1, "'p' must be a value greater than 1"
|
||||||
super(GeneralizedMeanPooling, self).__init__()
|
super(GeneralizedMeanPooling, self).__init__()
|
||||||
self.p = Parameter(torch.ones(1) * p)
|
self.p = Parameter(torch.ones(1) * p, requires_grad=p_trainable)
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.clamp = clamp
|
self.clamp = clamp
|
||||||
|
self.p_trainable = p_trainable
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
if isinstance(inputs, tuple):
|
if isinstance(inputs, tuple):
|
||||||
|
|
|
@ -44,6 +44,10 @@ def test_gem_neck():
|
||||||
|
|
||||||
# test gem_neck
|
# test gem_neck
|
||||||
neck = GeneralizedMeanPooling()
|
neck = GeneralizedMeanPooling()
|
||||||
|
|
||||||
|
# default p is trainable
|
||||||
|
assert neck.p.requires_grad
|
||||||
|
|
||||||
# batch_size, num_features, feature_size(2)
|
# batch_size, num_features, feature_size(2)
|
||||||
fake_input = torch.rand(1, 16, 24, 24)
|
fake_input = torch.rand(1, 16, 24, 24)
|
||||||
|
|
||||||
|
@ -61,6 +65,19 @@ def test_gem_neck():
|
||||||
assert output[0].shape == (1, 8)
|
assert output[0].shape == (1, 8)
|
||||||
assert output[1].shape == (1, 16)
|
assert output[1].shape == (1, 16)
|
||||||
|
|
||||||
|
# test gem_neck with p_trainable=False
|
||||||
|
neck = GeneralizedMeanPooling(p_trainable=False)
|
||||||
|
|
||||||
|
# p is not trainable
|
||||||
|
assert not neck.p.requires_grad
|
||||||
|
|
||||||
|
# batch_size, num_features, feature_size(2)
|
||||||
|
fake_input = torch.rand(1, 16, 24, 24)
|
||||||
|
|
||||||
|
output = neck(fake_input)
|
||||||
|
# batch_size, num_features
|
||||||
|
assert output.shape == (1, 16)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# p must be a value greater then 1
|
# p must be a value greater then 1
|
||||||
GeneralizedMeanPooling(p=0.5)
|
GeneralizedMeanPooling(p=0.5)
|
||||||
|
|
Loading…
Reference in New Issue