mirror of https://github.com/PyRetri/PyRetri.git
# -*- coding: utf-8 -*-

from typing import Dict, Optional

import numpy as np
import torch

from ..aggregators_base import AggregatorBase
from ...registry import AGGREGATORS
from ....index.utils import feature_loader
@AGGREGATORS.register
class PWA(AggregatorBase):
    """
    Part-based Weighting Aggregation.
    c.f. https://arxiv.org/abs/1705.01247

    Hyper-Params
        train_fea_dir (str): path of the feature directory used for selecting channels.
        train_fea_names (list): names of the features used for selecting channels.
        n_proposal (int): number of proposal channels to be selected.
        alpha (float): exponent used to compute the spatial weight.
        beta (float): exponent used to normalize the spatial weight.
    """

    default_hyper_params = {
        "train_fea_dir": "",
        "n_proposal": 25,
        "alpha": 2.0,
        "beta": 2.0,
        "train_fea_names": ["pool5_GAP"],
    }
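
    # A minimal sketch of overriding these defaults when constructing the
    # aggregator (the directory path below is a placeholder for illustration,
    # not a path shipped with PyRetri):
    #
    #   pwa = PWA(hps={"train_fea_dir": "/path/to/train/features", "n_proposal": 25})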

    def __init__(self, hps: Optional[Dict] = None) -> None:
        """
        Args:
            hps (dict): default hyper parameters in a dict (keys, values).
        """
        super(PWA, self).__init__(hps)
        self.first_show = True
        assert self._hyper_params["train_fea_dir"] != "", "train_fea_dir must not be empty"
        self.selected_proposals_idx = None
        self.train()
    def train(self) -> None:
        # Select the n_proposal channels whose activations vary the most over
        # the training features; their indices serve as part detectors at
        # aggregation time.
        n_proposal = self._hyper_params["n_proposal"]
        stacked_fea, _, pos_info = feature_loader.load(
            self._hyper_params["train_fea_dir"],
            self._hyper_params["train_fea_names"]
        )
        self.selected_proposals_idx = dict()
        for fea_name in pos_info:
            st_idx, ed_idx = pos_info[fea_name]
            fea = stacked_fea[:, st_idx: ed_idx]
            assert fea.ndim == 2, "invalid train feature"
            # Per-channel standard deviation is the selection criterion.
            channel_variance = np.std(fea, axis=0)
            selected_idx = channel_variance.argsort()[-n_proposal:]
            # Drop the aggregator suffix (e.g. "pool5_GAP" -> "pool5") so the
            # key matches the raw feature name seen in __call__.
            fea_name = "_".join(fea_name.split("_")[:-1])
            self.selected_proposals_idx[fea_name] = selected_idx.tolist()
    def __call__(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        alpha, beta = self._hyper_params["alpha"], self._hyper_params["beta"]
        ret = dict()
        for key in features:
            fea = features[key]
            if fea.ndimension() == 4:
                assert key in self.selected_proposals_idx, "{} is not in {}".format(key, self.selected_proposals_idx.keys())
                proposals_idx = np.array(self.selected_proposals_idx[key])
                # Each selected channel is a part detector: its (H, W) response
                # map is power-normalized into a spatial weight.
                proposals = fea[:, proposals_idx, :, :]
                power_norm = (proposals ** alpha).sum(dim=(2, 3), keepdim=True) ** (1.0 / alpha)
                normed_proposals = (proposals / (power_norm + 1e-5)) ** (1.0 / beta)
                # Weighted sum-pooling: pool all channels once per proposal,
                # then concatenate the pooled vectors.
                fea = (fea[:, None, :, :, :] * normed_proposals[:, :, None, :, :]).sum(dim=(3, 4))
                fea = fea.view(fea.shape[0], -1)
                ret[key + "_{}".format(self.__class__.__name__)] = fea
            else:
                # In case of fc feature.
                assert fea.ndimension() == 2
                if self.first_show:
                    print("[PWA Aggregator]: found a 2-dimensional feature, skipping aggregation")
                    self.first_show = False
                ret[key] = fea
        return ret
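
# A minimal, self-contained sketch of the weighted sum-pooling that __call__
# performs on a 4-d feature map, using random tensors. It bypasses
# feature_loader and the trained channel selection, so the shapes and the
# proposal indices below are illustrative assumptions, not PyRetri defaults:
#
#   import torch
#
#   N, C, H, W = 2, 512, 7, 7           # batch of conv-style feature maps
#   alpha, beta = 2.0, 2.0
#   fea = torch.rand(N, C, H, W)
#   proposals_idx = torch.arange(25)    # stand-in for the selected channels
#
#   proposals = fea[:, proposals_idx, :, :]                        # (N, 25, H, W)
#   power_norm = (proposals ** alpha).sum(dim=(2, 3), keepdim=True) ** (1.0 / alpha)
#   weights = (proposals / (power_norm + 1e-5)) ** (1.0 / beta)    # spatial weights
#   pooled = (fea[:, None] * weights[:, :, None]).sum(dim=(3, 4))  # (N, 25, C)
#   pooled = pooled.view(N, -1)                                    # (N, 25 * C)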