PyRetri/pyretri/extract/aggregator/aggregators_impl/pwa.py

# -*- coding: utf-8 -*-

import torch
import numpy as np

from ..aggregators_base import AggregatorBase
from ...registry import AGGREGATORS
from ....index.utils import feature_loader

from typing import Dict


@AGGREGATORS.register
class PWA(AggregatorBase):
    """
    Part-based Weighting Aggregation.
    c.f. https://arxiv.org/abs/1705.01247

    Hyper-Params
        train_fea_dir (str): path of the feature directory used for selecting channels.
        n_proposal (int): number of proposal channels (part detectors) to be selected.
        alpha (float): exponent used to calculate the spatial weight.
        beta (float): exponent used to calculate the spatial weight.
        train_fea_names (list): names of the stored training features used for channel selection.
    """
    default_hyper_params = {
        "train_fea_dir": "",
        "n_proposal": 25,
        "alpha": 2.0,
        "beta": 2.0,
        "train_fea_names": ["pool5_GAP"],
    }

    def __init__(self, hps: Dict = None):
        """
        Args:
            hps (dict): default hyper parameters in a dict (keys, values).
        """
        super(PWA, self).__init__(hps)
        self.first_show = True
        assert self._hyper_params["train_fea_dir"] != "", "train_fea_dir must be specified for PWA"
        self.selected_proposals_idx = None
        self.train()

    def train(self) -> None:
        """
        Select the feature channels with the highest variance on the training set;
        their activation maps act as part detectors at aggregation time.
        """
        n_proposal = self._hyper_params["n_proposal"]
        stacked_fea, _, pos_info = feature_loader.load(
            self._hyper_params["train_fea_dir"],
            self._hyper_params["train_fea_names"]
        )

        self.selected_proposals_idx = dict()
        for fea_name in pos_info:
            st_idx, ed_idx = pos_info[fea_name]
            fea = stacked_fea[:, st_idx: ed_idx]
            assert fea.ndim == 2, "invalid train feature"
            # Channels varying most across training images are assumed to fire on
            # discriminative parts; keep the n_proposal most variant ones.
            channel_variance = np.std(fea, axis=0)
            selected_idx = channel_variance.argsort()[-n_proposal:]
            # Strip the aggregator suffix (e.g. "pool5_GAP" -> "pool5") so the key
            # matches the raw feature name seen in __call__.
            fea_name = "_".join(fea_name.split("_")[:-1])
            self.selected_proposals_idx[fea_name] = selected_idx.tolist()
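
    # Channel-selection sketch (illustrative shapes, not taken from the repo): with
    # e.g. 1000 training images and 2048-channel "pool5_GAP" features, the slice
    # `stacked_fea[:, st_idx:ed_idx]` is a (1000, 2048) matrix, `channel_variance`
    # is a (2048,) vector, and the indices of its 25 largest entries become the
    # part detectors that __call__ uses to weight the raw feature maps.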

    def __call__(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        alpha, beta = self._hyper_params["alpha"], self._hyper_params["beta"]

        ret = dict()
        for key in features:
            fea = features[key]
            if fea.ndimension() == 4:
                assert key in self.selected_proposals_idx, "{} is not in {}".format(key, self.selected_proposals_idx.keys())
                proposals_idx = np.array(self.selected_proposals_idx[key])
                # Activation maps of the selected channels serve as part detectors.
                proposals = fea[:, proposals_idx, :, :]
                # Power-normalize each part detector map to obtain spatial weights.
                power_norm = (proposals ** alpha).sum(dim=(2, 3), keepdim=True) ** (1.0 / alpha)
                normed_proposals = (proposals / (power_norm + 1e-5)) ** (1.0 / beta)

                # Weighted sum-pooling: one descriptor per part detector, concatenated.
                fea = (fea[:, None, :, :, :] * normed_proposals[:, :, None, :, :]).sum(dim=(3, 4))
                fea = fea.view(fea.shape[0], -1)
                ret[key + "_{}".format(self.__class__.__name__)] = fea
            else:
                # In case of fc feature.
                assert fea.ndimension() == 2
                if self.first_show:
                    print("[PWA Aggregator]: found 2-dimensional feature, skipping aggregation")
                    self.first_show = False
                ret[key] = fea
        return ret
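

# Minimal self-check sketch (not part of the original module): it reproduces the
# weighting arithmetic from PWA.__call__ on a random feature map, assuming the
# first 25 channels stand in for the variance-selected part detectors, so it runs
# without a "train_fea_dir" on disk.
if __name__ == "__main__":
    torch.manual_seed(0)
    fea = torch.rand(2, 512, 7, 7)           # (batch, channels, H, W), e.g. a pool5 map
    proposals_idx = np.arange(25)            # placeholder for selected channel indices
    alpha, beta = 2.0, 2.0

    proposals = fea[:, proposals_idx, :, :]  # (2, 25, 7, 7) part detector maps
    power_norm = (proposals ** alpha).sum(dim=(2, 3), keepdim=True) ** (1.0 / alpha)
    weights = (proposals / (power_norm + 1e-5)) ** (1.0 / beta)

    # Each selected channel yields one weighted sum-pooled 512-d descriptor;
    # concatenating them gives a 25 * 512 = 12800-d representation per image.
    agg = (fea[:, None, :, :, :] * weights[:, :, None, :, :]).sum(dim=(3, 4))
    print(agg.view(agg.shape[0], -1).shape)  # torch.Size([2, 12800])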