fast-reid/fastreid/layers/pooling.py

22 lines
497 B
Python
Raw Normal View History

2020-02-10 07:38:56 +08:00
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
__all__ = ['GeM',]
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over the last two dimensions.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` over the full
    window spanned by the last two dimensions of the input, so every
    ``H x W`` map is reduced to a ``1 x 1`` map.  With ``p == 1`` this is
    average pooling of the clamped input.

    Args:
        p: Initial value of the pooling exponent.  It is stored as a
            one-element ``Parameter`` and is therefore trainable.
        eps: Lower bound applied to the input before it is raised to ``p``.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        # Registered under the name ``p`` so it appears in ``parameters()``
        # and in the state dict.
        self.p = Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` over its last two dimensions.

        Args:
            x: Input tensor accepted by ``F.avg_pool2d``; the last two
                dimensions are the ones being pooled.

        Returns:
            Tensor with the last two dimensions reduced to size 1.
        """
        # The kernel covers the whole map, making this a global pooling.
        window = (x.size(-2), x.size(-1))
        # Clamping keeps the base strictly positive so that raising it to a
        # non-integer power stays well defined.
        powered = x.clamp(min=self.eps).pow(self.p)
        pooled = F.avg_pool2d(powered, window)
        return pooled.pow(1. / self.p)

    def extra_repr(self):
        """Show the current exponent and ``eps`` in the module's ``repr``."""
        return 'p={}, eps={}'.format(self.p.data.tolist(), self.eps)