mirror of https://github.com/PyRetri/PyRetri.git
46 lines
1.3 KiB
Python
46 lines
1.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
import torch
|
|
import numpy as np
|
|
|
|
from ..splitter_base import SplitterBase
|
|
from ...registry import SPLITTERS
|
|
|
|
from typing import Dict
|
|
|
|
|
|
@SPLITTERS.register
class PCB(SplitterBase):
    """
    PCB function to split feature maps.
    c.f. http://openaccess.thecvf.com/content_ECCV_2018/papers/Yifan_Sun_Beyond_Part_Models_ECCV_2018_paper.pdf

    Every 4-D feature map is cut along dimension 2 (its height) into
    ``stripe_num`` consecutive, non-overlapping stripes of equal height, and
    the untouched map is returned alongside the stripes.

    Hyper-Params:
        stripe_num (int): the number of stripes divided.
    """
    default_hyper_params = {
        'stripe_num': 2,
    }

    def __init__(self, hps: Dict or None = None):
        """
        Args:
            hps (dict): default hyper parameters in a dict (keys, values).
        """
        super(PCB, self).__init__(hps)

    def __call__(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Split every feature map in ``features`` into horizontal stripes.

        Args:
            features (dict): maps a feature name to a 4-D tensor; dimension 2
                is treated as the height that gets divided.

        Returns:
            dict: for each input ``key`` the entries ``key + "_part_{i}"``
                (``i`` in ``0 .. stripe_num - 1``) holding the i-th stripe, and
                ``key + "_global"`` holding the unmodified tensor.

        Raises:
            AssertionError: if a tensor is not 4-D, or if ``stripe_num`` is
                larger than its height.
        """
        # NOTE(review): the value is read from the class-level defaults, exactly
        # as before; whatever was passed through ``hps`` is not consulted in
        # this method -- confirm against SplitterBase whether that is intended.
        # Normalised to int once so the bound check, the stride and the loop
        # count all use the same integral value.
        stripe_num = int(self.default_hyper_params["stripe_num"])

        ret = dict()
        for key in features:
            fea = features[key]
            assert fea.ndimension() == 4
            assert stripe_num <= fea.shape[2], \
                'stripe num must be less than or equal to the height of fea'

            # Floor division: when the height is not a multiple of stripe_num,
            # the trailing rows are part of no stripe (only of the global map).
            stride = fea.shape[2] // stripe_num

            for i in range(stripe_num):
                ret[key + "_part_{}".format(i)] = fea[:, :, stride * i: stride * (i + 1), :]
            ret[key + "_global"] = fea
        return ret
|