[Enhance] Allow toggling whether the GeM Pooling parameter `p` is trainable. (#1246)

* Enable to toggle whether Gem Pooling is trainable or not.

* Add test case whether Gem Pooling is trainable or not.

* Enable to toggle whether Gem Pooling is trainable or not by requires_grad

---------

Co-authored-by: Yusuke Fujimoto <yusuke.fujimoto@rist.co.jp>
pull/1386/head
fam_taro 2023-02-09 12:27:05 +09:00 committed by GitHub
parent a3f2effb17
commit 4ce7be17c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 8 deletions

View File

@ -22,20 +22,20 @@ class GeneralizedMeanPooling(nn.Module):
has a batch dimension of size 1, which can lead to unexpected errors.
Args:
p (float): Parameter value.
Default: 3.
eps (float): epsilon.
Default: 1e-6
clamp (bool): Use clamp before pooling.
Default: True
p (float): Parameter value. Defaults to 3.
eps (float): epsilon. Defaults to 1e-6.
clamp (bool): Use clamp before pooling. Defaults to True.
p_trainable (bool): Toggle whether Parameter p is trainable or not.
Defaults to True.
"""
def __init__(self, p=3., eps=1e-6, clamp=True, p_trainable=True):
    """Initialize GeM pooling.

    Args:
        p (float): Pooling parameter. Must be >= 1. Defaults to 3.
        eps (float): Epsilon used when clamping the input. Defaults to 1e-6.
        clamp (bool): Whether to clamp the input before pooling.
            Defaults to True.
        p_trainable (bool): Whether ``p`` is updated during training.
            Defaults to True.

    Raises:
        AssertionError: If ``p`` is smaller than 1.
    """
    assert p >= 1, "'p' must be a value greater than 1"
    super(GeneralizedMeanPooling, self).__init__()
    # requires_grad toggles whether p participates in gradient updates.
    self.p = Parameter(torch.ones(1) * p, requires_grad=p_trainable)
    self.eps = eps
    self.clamp = clamp
    self.p_trainable = p_trainable
def forward(self, inputs):
if isinstance(inputs, tuple):

View File

@ -44,6 +44,10 @@ def test_gem_neck():
# test gem_neck
neck = GeneralizedMeanPooling()
# default p is trainable
assert neck.p.requires_grad
# batch_size, num_features, feature_size(2)
fake_input = torch.rand(1, 16, 24, 24)
@ -61,6 +65,19 @@ def test_gem_neck():
assert output[0].shape == (1, 8)
assert output[1].shape == (1, 16)
# test gem_neck with p_trainable=False
neck = GeneralizedMeanPooling(p_trainable=False)
# p is not trainable
assert not neck.p.requires_grad
# batch_size, num_features, feature_size(2)
fake_input = torch.rand(1, 16, 24, 24)
output = neck(fake_input)
# batch_size, num_features
assert output.shape == (1, 16)
with pytest.raises(AssertionError):
# p must be a value greater than 1
GeneralizedMeanPooling(p=0.5)