[Feature] Support gem pooling (#677)
* add gem pooling
* add example config
* fix params
* add assert
* add param clamp
* add test assert
* add clamp
* fix conflict

parent fcd57913ae
commit 43024cda73
@@ -0,0 +1,17 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=34,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GeneralizedMeanPooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
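For context, this example config inserts the new neck between a ResNet-34 backbone (whose final stage outputs 512 channels, matching in_channels=512 in the head) and a linear classification head. A minimal usage sketch, assuming mmcls's build_classifier helper and the registry entries added by this commit:

import torch
from mmcls.models import build_classifier

model = build_classifier(dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet', depth=34, num_stages=4, out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GeneralizedMeanPooling'),
    head=dict(
        type='LinearClsHead', num_classes=1000, in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0), topk=(1, 5))))

# extract_feat runs backbone + neck; the pooled GeM feature is (N, 512)
# (depending on the mmcls version it may come back wrapped in a tuple).
feats = model.extract_feat(torch.rand(1, 3, 224, 224))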
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .gap import GlobalAveragePooling
+from .gem import GeneralizedMeanPooling
 from .hr_fuse import HRFuseScales

-__all__ = ['GlobalAveragePooling', 'HRFuseScales']
+__all__ = ['GlobalAveragePooling', 'GeneralizedMeanPooling', 'HRFuseScales']
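Re-exporting GeneralizedMeanPooling here is what enables the short import path used by the tests below. A quick smoke check, assuming the package is installed with this commit applied:

import torch
from mmcls.models.necks import GeneralizedMeanPooling

neck = GeneralizedMeanPooling()       # defaults: p=3., eps=1e-6, clamp=True
feat = torch.rand(2, 512, 7, 7)       # (N, C, H, W) backbone feature map
out = neck(feat)
print(out.shape)                      # torch.Size([2, 512])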
@@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter

from ..builder import NECKS


def gem(x: Tensor, p: Parameter, eps: float = 1e-6, clamp=True) -> Tensor:
    if clamp:
        x = x.clamp(min=eps)
    return F.avg_pool2d(x.pow(p), (x.size(-2), x.size(-1))).pow(1. / p)


@NECKS.register_module()
class GeneralizedMeanPooling(nn.Module):
    """Generalized Mean Pooling neck.

    Note that we use `view` to remove extra channel after pooling. We do not
    use `squeeze` as it will also remove the batch dimension when the tensor
    has a batch dimension of size 1, which can lead to unexpected errors.

    Args:
        p (float): Parameter value.
            Default: 3.
        eps (float): epsilon.
            Default: 1e-6
        clamp (bool): Use clamp before pooling.
            Default: True
    """

    def __init__(self, p=3., eps=1e-6, clamp=True):
        assert p >= 1, "'p' must be a value greater than or equal to 1"
        super(GeneralizedMeanPooling, self).__init__()
        self.p = Parameter(torch.ones(1) * p)
        self.eps = eps
        self.clamp = clamp

    def forward(self, inputs):
        if isinstance(inputs, tuple):
            outs = tuple([
                gem(x, p=self.p, eps=self.eps, clamp=self.clamp)
                for x in inputs
            ])
            outs = tuple(
                [out.view(x.size(0), -1) for out, x in zip(outs, inputs)])
        elif isinstance(inputs, torch.Tensor):
            outs = gem(inputs, p=self.p, eps=self.eps, clamp=self.clamp)
            outs = outs.view(inputs.size(0), -1)
        else:
            raise TypeError('neck inputs should be tuple or torch.tensor')
        return outs
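For each channel, gem() computes ((1/HW) * sum(x^p))^(1/p) over the H x W spatial grid, so p = 1 recovers global average pooling and large p approaches global max pooling (and because p is wrapped in a Parameter, it is learnable). A standalone sanity check of those two limits, not part of the commit:

import torch
from torch.nn import functional as F

def gem(x, p, eps=1e-6):
    # same computation as the neck above, with a plain float p
    x = x.clamp(min=eps)
    return F.avg_pool2d(x.pow(p), (x.size(-2), x.size(-1))).pow(1. / p)

x = torch.rand(2, 8, 7, 7) + 0.1      # strictly positive activations

# p = 1: identical to global average pooling
assert torch.allclose(gem(x, p=1.0), F.adaptive_avg_pool2d(x, 1), atol=1e-5)

# large p: converges to global max pooling
assert torch.allclose(gem(x, p=500.0), F.adaptive_max_pool2d(x, 1), atol=2e-2)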
@@ -2,7 +2,8 @@
 import pytest
 import torch

-from mmcls.models.necks import GlobalAveragePooling, HRFuseScales
+from mmcls.models.necks import (GeneralizedMeanPooling, GlobalAveragePooling,
+                                HRFuseScales)


 def test_gap_neck():
@@ -39,6 +40,32 @@ def test_gap_neck():
         GlobalAveragePooling(dim='other')


+def test_gem_neck():
+
+    # test gem_neck
+    neck = GeneralizedMeanPooling()
+    # batch_size, num_features, feature_size(2)
+    fake_input = torch.rand(1, 16, 24, 24)
+
+    output = neck(fake_input)
+    # batch_size, num_features
+    assert output.shape == (1, 16)
+
+    # test tuple input gem_neck
+    neck = GeneralizedMeanPooling()
+    # batch_size, num_features, feature_size(2)
+    fake_input = (torch.rand(1, 8, 24, 24), torch.rand(1, 16, 24, 24))
+
+    output = neck(fake_input)
+    # batch_size, num_features
+    assert output[0].shape == (1, 8)
+    assert output[1].shape == (1, 16)
+
+    with pytest.raises(AssertionError):
+        # p must be a value greater than or equal to 1
+        GeneralizedMeanPooling(p=0.5)
+
+
 def test_hr_fuse_scales():

     in_channels = (18, 32, 64, 128)
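One property the tests leave implicit: because p is stored as an nn.Parameter, it receives gradients and is tuned jointly with the rest of the model rather than staying fixed at its initial value. A short sketch, not part of the commit, that confirms this:

import torch
from mmcls.models.necks import GeneralizedMeanPooling

neck = GeneralizedMeanPooling(p=3.)
out = neck(torch.rand(1, 16, 24, 24))
out.sum().backward()
print(neck.p.grad is not None)        # True: p is trained with the network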