diff --git a/mmcls/models/necks/gem.py b/mmcls/models/necks/gem.py index fd048469d..ce6edb36d 100644 --- a/mmcls/models/necks/gem.py +++ b/mmcls/models/necks/gem.py @@ -22,20 +22,20 @@ class GeneralizedMeanPooling(nn.Module): has a batch dimension of size 1, which can lead to unexpected errors. Args: - p (float): Parameter value. - Default: 3. - eps (float): epsilon. - Default: 1e-6 - clamp (bool): Use clamp before pooling. - Default: True + p (float): Parameter value. Defaults to 3. + eps (float): epsilon. Defaults to 1e-6. + clamp (bool): Use clamp before pooling. Defaults to True + p_trainable (bool): Toggle whether Parameter p is trainable or not. + Defaults to True. """ - def __init__(self, p=3., eps=1e-6, clamp=True): + def __init__(self, p=3., eps=1e-6, clamp=True, p_trainable=True): assert p >= 1, "'p' must be a value greater than 1" super(GeneralizedMeanPooling, self).__init__() - self.p = Parameter(torch.ones(1) * p) + self.p = Parameter(torch.ones(1) * p, requires_grad=p_trainable) self.eps = eps self.clamp = clamp + self.p_trainable = p_trainable def forward(self, inputs): if isinstance(inputs, tuple): diff --git a/tests/test_models/test_necks.py b/tests/test_models/test_necks.py index 6d8e1ea45..8fa2156e2 100644 --- a/tests/test_models/test_necks.py +++ b/tests/test_models/test_necks.py @@ -44,6 +44,10 @@ def test_gem_neck(): # test gem_neck neck = GeneralizedMeanPooling() + + # default p is trainable + assert neck.p.requires_grad + # batch_size, num_features, feature_size(2) fake_input = torch.rand(1, 16, 24, 24) @@ -61,6 +65,19 @@ def test_gem_neck(): assert output[0].shape == (1, 8) assert output[1].shape == (1, 16) + # test gem_neck with p_trainable=False + neck = GeneralizedMeanPooling(p_trainable=False) + + # p is not trainable + assert not neck.p.requires_grad + + # batch_size, num_features, feature_size(2) + fake_input = torch.rand(1, 16, 24, 24) + + output = neck(fake_input) + # batch_size, num_features + assert output.shape == (1, 16) + with pytest.raises(AssertionError): # p must be a value greater then 1 GeneralizedMeanPooling(p=0.5)