diff --git a/configs/_base_/models/resnet34_gem.py b/configs/_base_/models/resnet34_gem.py
new file mode 100644
index 00000000..5c0e0d3e
--- /dev/null
+++ b/configs/_base_/models/resnet34_gem.py
@@ -0,0 +1,17 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='ResNet',
+        depth=34,
+        num_stages=4,
+        out_indices=(3, ),
+        style='pytorch'),
+    neck=dict(type='GeneralizedMeanPooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=512,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/mmcls/models/necks/__init__.py b/mmcls/models/necks/__init__.py
index 6f3ae47c..aa5411f0 100644
--- a/mmcls/models/necks/__init__.py
+++ b/mmcls/models/necks/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .gap import GlobalAveragePooling
+from .gem import GeneralizedMeanPooling
 from .hr_fuse import HRFuseScales
 
-__all__ = ['GlobalAveragePooling', 'HRFuseScales']
+__all__ = ['GlobalAveragePooling', 'GeneralizedMeanPooling', 'HRFuseScales']
diff --git a/mmcls/models/necks/gem.py b/mmcls/models/necks/gem.py
new file mode 100644
index 00000000..f499357c
--- /dev/null
+++ b/mmcls/models/necks/gem.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+from torch.nn.parameter import Parameter
+
+from ..builder import NECKS
+
+
+def gem(x: Tensor, p: Parameter, eps: float = 1e-6, clamp=True) -> Tensor:
+    if clamp:
+        x = x.clamp(min=eps)
+    return F.avg_pool2d(x.pow(p), (x.size(-2), x.size(-1))).pow(1. / p)
+
+
+@NECKS.register_module()
+class GeneralizedMeanPooling(nn.Module):
+    """Generalized Mean Pooling neck.
+
+    Note that we use `view` to remove the extra spatial dimensions after
+    pooling. We do not use `squeeze`, as it would also remove the batch
+    dimension when the batch size is 1, which can lead to unexpected errors.
+
+    Args:
+        p (float): The exponent of the generalized mean.
+            Default: 3.
+        eps (float): Epsilon used to clamp small values for numerical
+            stability. Default: 1e-6.
+        clamp (bool): Whether to clamp the input before pooling.
+            Default: True.
+    """
+
+    def __init__(self, p=3., eps=1e-6, clamp=True):
+        assert p >= 1, "'p' must be a value greater than or equal to 1"
+        super(GeneralizedMeanPooling, self).__init__()
+        self.p = Parameter(torch.ones(1) * p)
+        self.eps = eps
+        self.clamp = clamp
+
+    def forward(self, inputs):
+        if isinstance(inputs, tuple):
+            outs = tuple([
+                gem(x, p=self.p, eps=self.eps, clamp=self.clamp)
+                for x in inputs
+            ])
+            outs = tuple(
+                [out.view(x.size(0), -1) for out, x in zip(outs, inputs)])
+        elif isinstance(inputs, torch.Tensor):
+            outs = gem(inputs, p=self.p, eps=self.eps, clamp=self.clamp)
+            outs = outs.view(inputs.size(0), -1)
+        else:
+            raise TypeError('neck inputs should be tuple or torch.Tensor')
+        return outs
diff --git a/tests/test_models/test_neck.py b/tests/test_models/test_neck.py
index 08e2e421..b554e3da 100644
--- a/tests/test_models/test_neck.py
+++ b/tests/test_models/test_neck.py
@@ -2,7 +2,8 @@
 import pytest
 import torch
 
-from mmcls.models.necks import GlobalAveragePooling, HRFuseScales
+from mmcls.models.necks import (GeneralizedMeanPooling, GlobalAveragePooling,
+                                HRFuseScales)
 
 
 def test_gap_neck():
@@ -39,6 +40,32 @@
         GlobalAveragePooling(dim='other')
 
 
+def test_gem_neck():
+
+    # test single tensor input
+    neck = GeneralizedMeanPooling()
+    # batch_size, num_features, feature_size(2)
+    fake_input = torch.rand(1, 16, 24, 24)
+
+    output = neck(fake_input)
+    # batch_size, num_features
+    assert output.shape == (1, 16)
+
+    # test tuple of tensors input
+    neck = GeneralizedMeanPooling()
+    # batch_size, num_features, feature_size(2)
+    fake_input = (torch.rand(1, 8, 24, 24), torch.rand(1, 16, 24, 24))
+
+    output = neck(fake_input)
+    # batch_size, num_features
+    assert output[0].shape == (1, 8)
+    assert output[1].shape == (1, 16)
+
+    with pytest.raises(AssertionError):
+        # 'p' must be greater than or equal to 1
+        GeneralizedMeanPooling(p=0.5)
+
+
 def test_hr_fuse_scales():
     in_channels = (18, 32, 64, 128)
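For context on what this neck computes: generalized mean (GeM) pooling raises activations to the power `p`, averages them spatially, and takes the `1/p` root, so it interpolates between global average pooling (`p = 1`) and global max pooling (`p → ∞`). Below is a minimal standalone sketch of that behavior; `gem_pool` is an illustrative re-implementation of the `gem()` helper above (fixed clamp, plain float `p`), not part of this patch:

```python
import torch
import torch.nn.functional as F


def gem_pool(x: torch.Tensor, p: float, eps: float = 1e-6) -> torch.Tensor:
    # Same formula as gem() above: clamp, raise to p, spatial mean, 1/p root.
    x = x.clamp(min=eps)
    return F.avg_pool2d(x.pow(p), (x.size(-2), x.size(-1))).pow(1. / p)


x = torch.rand(2, 8, 7, 7)

# p = 1 reduces GeM to global average pooling.
assert torch.allclose(gem_pool(x, p=1.), F.adaptive_avg_pool2d(x, 1),
                      atol=1e-5)

# As p grows, GeM approaches global max pooling.
assert torch.allclose(gem_pool(x, p=200.), F.adaptive_max_pool2d(x, 1),
                      atol=0.05)
```

Making `p` a learnable `Parameter` (as the neck does, initialized to the default of 3) lets each model tune where it sits on that average-to-max spectrum during training.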