PyRetri/pyretri/extract/splitter/splitter_impl/pcb.py

46 lines
1.3 KiB
Python

# -*- coding: utf-8 -*-
import torch
import numpy as np
from ..splitter_base import SplitterBase
from ...registry import SPLITTERS
from typing import Dict
@SPLITTERS.register
class PCB(SplitterBase):
    """
    PCB function to split feature maps.

    Each 4-D feature map is cut along dimension 2 (called "height" by the
    validation message below) into ``stripe_num`` equally sized, non-overlapping
    stripes. The unsplit feature map is returned alongside the stripes.

    c.f. http://openaccess.thecvf.com/content_ECCV_2018/papers/Yifan_Sun_Beyond_Part_Models_ECCV_2018_paper.pdf

    Hyper-Params:
        stripe_num (int): the number of stripes divided.
    """
    default_hyper_params = {
        'stripe_num': 2,
    }

    def __init__(self, hps: Dict or None = None):
        """
        Args:
            hps (dict): hyper parameters in a dict (keys, values), or None.
                Forwarded unchanged to the base-class constructor.
        """
        super(PCB, self).__init__(hps)

    def __call__(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Split every feature map in ``features`` into horizontal stripes.

        Args:
            features (dict): maps a feature name to a 4-D tensor. The layout is
                presumably (batch, channel, height, width) -- only the rank and the
                use of dimension 2 as "height" are established by this code.

        Returns:
            dict: for every input ``key``, the entries ``key + "_part_{i}"`` for
            ``i`` in ``[0, stripe_num)`` holding the i-th stripe, plus
            ``key + "_global"`` holding the original tensor. Stripes are slices of
            the input tensor, not copies.

        Raises:
            AssertionError: if ``stripe_num`` is not positive, a feature map is not
                4-D, or ``stripe_num`` exceeds the size of dimension 2.

        Note:
            The stripe size is ``height // stripe_num``; when the height is not a
            multiple of ``stripe_num`` the trailing ``height % stripe_num`` rows are
            not part of any stripe (they are still present in the ``_global`` entry).
        """
        # Use the per-instance hyper-params so values passed to ``__init__`` are
        # honoured instead of always using the class-level defaults.
        # NOTE(review): the base class is not visible here; it presumably stores the
        # merged settings on ``_hyper_params`` -- confirm. If the attribute is absent,
        # the class defaults are used, which matches the previous behaviour.
        hyper_params = getattr(self, "_hyper_params", self.default_hyper_params)

        # Normalise once so the check, the stride and the loop all see the same int.
        stripe_num = int(hyper_params["stripe_num"])
        # Fail with a clear message rather than a ZeroDivisionError further down.
        assert stripe_num > 0, 'stripe num must be a positive integer'

        ret = dict()
        for key in features:
            fea = features[key]
            assert fea.ndimension() == 4
            height = fea.shape[2]
            assert stripe_num <= height, \
                'stripe num must be less than or equal to the height of fea'

            stride = height // stripe_num
            for i in range(stripe_num):
                ret[key + "_part_{}".format(i)] = fea[:, :, stride * i: stride * (i + 1), :]
            ret[key + "_global"] = fea
        return ret