mirror of https://github.com/JDAI-CV/fast-reid.git
22 lines
497 B
Python
22 lines
497 B
Python
|
# encoding: utf-8
|
||
|
"""
|
||
|
@author: liaoxingyu
|
||
|
@contact: sherlockliao01@gmail.com
|
||
|
"""
|
||
|
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
from torch.nn.parameter import Parameter
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
__all__ = ['GeM']


class GeM(nn.Module):
    r"""Generalized-mean (GeM) pooling with a learnable exponent.

    The input is reduced over its two trailing dimensions as

        y = mean(clamp(x, min=eps) ** p) ** (1 / p)

    With ``p = 1`` this is plain average pooling. As ``p`` grows, the result
    approaches max pooling. ``p`` is registered as an ``nn.Parameter``, so it
    is learnable along with the rest of the network.

    Args:
        p: initial value of the pooling exponent.
        eps: lower bound applied to the input before exponentiation. It keeps
            every base of ``pow`` strictly positive.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        # One-element tensor wrapped as a Parameter so the exponent is learnable.
        self.p = Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Apply GeM pooling over the last two dimensions of ``x``.

        Args:
            x: input tensor. NOTE(review): assumed to be ``(N, C, H, W)``
               feature maps; confirm with callers.

        Returns:
            Tensor with the last two dimensions reduced to size 1, e.g.
            ``(N, C, 1, 1)`` for a 4-D input.
        """
        # A kernel spanning the full trailing extent makes avg_pool2d a global
        # mean over those two dimensions.
        pooled = F.avg_pool2d(x.clamp(min=self.eps).pow(self.p),
                              (x.size(-2), x.size(-1)))
        return pooled.pow(1. / self.p)

    def extra_repr(self):
        """Show the current exponent and eps in ``print(module)``."""
        # Without this, nn.Module's repr would print only "GeM()".
        return 'p={:.4f}, eps={}'.format(self.p.item(), self.eps)
|