# -*- coding: utf-8 -*-

from typing import Dict, Optional

import torch

from ..aggregators_base import AggregatorBase
from ...registry import AGGREGATORS


@AGGREGATORS.register
class GeM(AggregatorBase):
    """
    Generalized-mean pooling.
    c.f. https://pdfs.semanticscholar.org/a2ca/e0ed91d8a3298b3209fc7ea0a4248b914386.pdf

    Hyper-Params:
        p (float): hyper-parameter for calculating generalized mean.
            If p = 1, GeM is equal to global average pooling, and
            if p = +infinity, GeM is equal to global max pooling.
    """
    default_hyper_params = {
        "p": 3.0,
    }

    def __init__(self, hps: Optional[Dict] = None):
        """
        Args:
            hps (dict): default hyper parameters in a dict (keys, values).
        """
        self.first_show = True
        super(GeM, self).__init__(hps)

    def __call__(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        p = self._hyper_params["p"]

        ret = dict()
        for key in features:
            fea = features[key]
            if fea.ndimension() == 4:
                # GeM: raise activations to the power p, average over the
                # spatial dimensions (h, w), then take the p-th root.
                fea = fea ** p
                h, w = fea.shape[2:]
                fea = fea.sum(dim=(2, 3)) * 1.0 / w / h
                fea = fea ** (1.0 / p)
                ret[key + "_{}".format(self.__class__.__name__)] = fea
            else:
                # In case of fc feature: already a (batch, dim) vector, nothing to pool.
                assert fea.ndimension() == 2
                if self.first_show:
                    print("[GeM Aggregator]: find 2-dimension feature map, skip aggregation")
                    self.first_show = False
                ret[key] = fea
        return ret
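
# A minimal usage sketch (illustration only, not part of the original file).
# It assumes AggregatorBase copies ``default_hyper_params`` (overridden by the
# ``hps`` argument) into ``self._hyper_params``, which is what ``__call__`` reads,
# and that this module is imported as part of the package (the relative imports
# above prevent running the file directly). The feature key "pool5" is hypothetical.
#
#   aggregator = GeM(hps={"p": 3.0})
#   features = {"pool5": torch.rand(2, 512, 7, 7)}   # (batch, channels, h, w)
#   aggregated = aggregator(features)
#   print(aggregated["pool5_GeM"].shape)             # -> torch.Size([2, 512])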