PyRetri/pyretri/extract/aggregator/aggregators_impl/spoc.py

55 lines
2.0 KiB
Python

# -*- coding: utf-8 -*-
import torch
import numpy as np
from ..aggregators_base import AggregatorBase
from ...registry import AGGREGATORS
from typing import Dict
@AGGREGATORS.register
class SPoC(AggregatorBase):
    """
    SPoC with center prior.

    Every 4-dimension feature is multiplied by a Gaussian weight centered on
    its last two dimensions and then summed over those two dimensions.
    2-dimension features are passed through unchanged.

    c.f. https://arxiv.org/pdf/1510.07493.pdf
    """
    default_hyper_params = dict()

    def __init__(self, hps: Dict or None = None):
        """
        Args:
            hps (dict): default hyper parameters in a dict (keys, values).
        """
        super(SPoC, self).__init__(hps)
        # The "skip aggregation" notice is printed only once per instance.
        self.first_show = True
        # Center-prior weights, keyed by (height, width, device).
        self.spatial_weight_cache = dict()

    def _get_spatial_weight(self, h: int, w: int, device: torch.device) -> torch.Tensor:
        """
        Return the center-prior weight for an (h, w) map, placed on ``device``.

        The weight is exp(-d^2 / (2 * sigma^2)), where d is the distance to
        the map center ((w - 1) / 2, (h - 1) / 2) and sigma = min(h, w) / 6.
        Results are cached per (h, w, device).

        Args:
            h (int): height of the feature map.
            w (int): width of the feature map.
            device (torch.device): device the weight must live on.

        Returns:
            torch.Tensor: weight of shape (1, 1, h, w).
        """
        cache_key = (h, w, device)
        spatial_weight = self.spatial_weight_cache.get(cache_key)
        if spatial_weight is None:
            sigma = min(h, w) / 2.0 / 3.0
            dtype = torch.get_default_dtype()
            x = torch.arange(w, dtype=dtype)
            y = torch.arange(h, dtype=dtype)[:, None]
            squared_dist = (x - (w - 1) / 2.0) ** 2 + (y - (h - 1) / 2.0) ** 2
            spatial_weight = torch.exp(-squared_dist / 2.0 / (sigma ** 2))
            # Follow the feature's device instead of assuming CUDA whenever it
            # is available; otherwise CPU features (or features on another
            # GPU) fail with a device mismatch in the multiplication.
            spatial_weight = spatial_weight[None, None, :, :].to(device)
            self.spatial_weight_cache[cache_key] = spatial_weight
        return spatial_weight

    def __call__(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Aggregate every feature in ``features``.

        Args:
            features (dict): mapping from feature name to tensor. Each tensor
                must have either 4 or 2 dimensions.

        Returns:
            dict: 4-dimension inputs are stored under "<key>_<class name>"
                after being weighted and summed over dims 2 and 3;
                2-dimension inputs are stored under their original key,
                unchanged.
        """
        ret = dict()
        for key in features:
            fea = features[key]
            if fea.ndimension() == 4:
                h, w = fea.shape[2:]
                spatial_weight = self._get_spatial_weight(h, w, fea.device)
                fea = (fea * spatial_weight).sum(dim=(2, 3))
                ret[key + "_{}".format(self.__class__.__name__)] = fea
            else:
                # In case of fc feature.
                assert fea.ndimension() == 2
                if self.first_show:
                    print("[SPoC Aggregator]: find 2-dimension feature map, skip aggregation")
                    self.first_show = False
                ret[key] = fea
        return ret